Преглед изворни кода

Feat/add retriever rerank (#1560)

Co-authored-by: jyong <jyong@dify.ai>
Jyong пре 1 година
родитељ
комит
4588831bff
44 измењених фајлова са 1903 додато и 168 уклоњено
  1. 46 11
      api/commands.py
  2. 48 0
      api/controllers/console/datasets/datasets.py
  3. 4 0
      api/controllers/console/datasets/datasets_document.py
  4. 5 6
      api/controllers/console/datasets/hit_testing.py
  5. 10 11
      api/controllers/console/workspace/models.py
  6. 4 0
      api/controllers/service_api/dataset/document.py
  7. 0 2
      api/core/agent/agent/multi_dataset_router_agent.py
  8. 158 0
      api/core/agent/agent/output_parser/retirver_dataset_agent.py
  9. 0 1
      api/core/agent/agent/structed_multi_dataset_router_agent.py
  10. 3 2
      api/core/agent/agent_executor.py
  11. 1 3
      api/core/callback_handler/index_tool_callback_handler.py
  12. 1 0
      api/core/completion.py
  13. 28 18
      api/core/data_loader/file_extractor.py
  14. 7 0
      api/core/index/vector_index/base.py
  15. 9 7
      api/core/index/vector_index/milvus_vector_index.py
  16. 18 0
      api/core/index/vector_index/qdrant_vector_index.py
  17. 8 1
      api/core/index/vector_index/weaviate_vector_index.py
  18. 5 5
      api/core/indexing_runner.py
  19. 39 0
      api/core/model_providers/model_factory.py
  20. 3 0
      api/core/model_providers/model_provider_factory.py
  21. 1 1
      api/core/model_providers/models/entity/model_params.py
  22. 0 0
      api/core/model_providers/models/reranking/__init__.py
  23. 36 0
      api/core/model_providers/models/reranking/base.py
  24. 73 0
      api/core/model_providers/models/reranking/cohere_reranking.py
  25. 152 0
      api/core/model_providers/providers/cohere_provider.py
  26. 2 1
      api/core/model_providers/rules/_providers.json
  27. 7 0
      api/core/model_providers/rules/cohere.json
  28. 85 42
      api/core/orchestrator_rule_parser.py
  29. 227 0
      api/core/tool/dataset_multi_retriever_tool.py
  30. 68 19
      api/core/tool/dataset_retriever_tool.py
  31. 1 1
      api/core/vector_store/milvus_vector_store.py
  32. 2 1
      api/core/vector_store/qdrant_vector_store.py
  33. 0 0
      api/core/vector_store/vector/milvus.py
  34. 47 3
      api/core/vector_store/vector/qdrant.py
  35. 505 0
      api/core/vector_store/vector/weaviate.py
  36. 19 1
      api/fields/dataset_fields.py
  37. 43 0
      api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py
  38. 18 4
      api/models/dataset.py
  39. 7 1
      api/models/model.py
  40. 3 2
      api/requirements.txt
  41. 10 1
      api/services/app_model_config_service.py
  42. 33 2
      api/services/dataset_service.py
  43. 79 22
      api/services/hit_testing_service.py
  44. 88 0
      api/services/retrieval_service.py

+ 46 - 11
api/commands.py

@@ -8,6 +8,8 @@ import time
 import uuid
 
 import click
+import qdrant_client
+from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm
 from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
@@ -484,6 +486,38 @@ def normalization_collections():
     click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
 
 
+@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
+def add_qdrant_full_text_index():
+    click.echo(click.style('Start add full text index.', fg='green'))
+    binds = db.session.query(DatasetCollectionBinding).all()
+    if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
+        qdrant_url = current_app.config['QDRANT_URL']
+        qdrant_api_key = current_app.config['QDRANT_API_KEY']
+        client = qdrant_client.QdrantClient(
+            qdrant_url,
+            api_key=qdrant_api_key,  # For Qdrant Cloud, None for local instance
+        )
+        for bind in binds:
+            try:
+                text_index_params = TextIndexParams(
+                    type=TextIndexType.TEXT,
+                    tokenizer=TokenizerType.MULTILINGUAL,
+                    min_token_len=2,
+                    max_token_len=20,
+                    lowercase=True
+                )
+                client.create_payload_index(bind.collection_name, 'page_content',
+                                            field_schema=text_index_params)
+            except Exception as e:
+                click.echo(
+                    click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+            click.echo(
+                click.style(
+                    'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
+                    fg='green'))
+
+
 def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
     with flask_app.app_context():
         try:
@@ -647,10 +681,10 @@ def update_app_model_configs(batch_size):
 
             pbar.update(len(data_batch))
 
+
 @click.command('migrate_default_input_to_dataset_query_variable')
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def migrate_default_input_to_dataset_query_variable(batch_size):
-
     click.secho("Starting...", fg='green')
 
     total_records = db.session.query(AppModelConfig) \
@@ -658,13 +692,13 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
         .filter(App.mode == 'completion') \
         .filter(AppModelConfig.dataset_query_variable == None) \
         .count()
-    
+
     if total_records == 0:
         click.secho("No data to migrate.", fg='green')
         return
 
     num_batches = (total_records + batch_size - 1) // batch_size
-    
+
     with tqdm(total=total_records, desc="Migrating Data") as pbar:
         for i in range(num_batches):
             offset = i * batch_size
@@ -697,14 +731,14 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
                     for form in user_input_form:
                         paragraph = form.get('paragraph')
                         if paragraph \
-                            and paragraph.get('variable') == 'query':
-                                data.dataset_query_variable = 'query'
-                                break
-                        
+                                and paragraph.get('variable') == 'query':
+                            data.dataset_query_variable = 'query'
+                            break
+
                         if paragraph \
-                            and paragraph.get('variable') == 'default_input':
-                                data.dataset_query_variable = 'default_input'
-                                break
+                                and paragraph.get('variable') == 'default_input':
+                            data.dataset_query_variable = 'default_input'
+                            break
 
                 db.session.commit()
 
@@ -712,7 +746,7 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
                 click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                             fg='red')
                 continue
-            
+
             click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
 
             pbar.update(len(data_batch))
@@ -731,3 +765,4 @@ def register_commands(app):
     app.cli.add_command(update_app_model_configs)
     app.cli.add_command(normalization_collections)
     app.cli.add_command(migrate_default_input_to_dataset_query_variable)
+    app.cli.add_command(add_qdrant_full_text_index)

+ 48 - 0
api/controllers/console/datasets/datasets.py

@@ -170,6 +170,7 @@ class DatasetApi(Resource):
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
+        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
         args = parser.parse_args()
 
         # The role of the current user in the ta table must be admin or owner
@@ -401,6 +402,7 @@ class DatasetApiKeyApi(Resource):
 
 class DatasetApiDeleteApi(Resource):
     resource_type = 'dataset'
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -436,6 +438,50 @@ class DatasetApiBaseUrlApi(Resource):
         }
 
 
+class DatasetRetrievalSettingApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        vector_type = current_app.config['VECTOR_STORE']
+        if vector_type == 'milvus':
+            return {
+                'retrieval_method': [
+                    'semantic_search'
+                ]
+            }
+        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+            return {
+                'retrieval_method': [
+                    'semantic_search', 'full_text_search', 'hybrid_search'
+                ]
+            }
+        else:
+            raise ValueError("Unsupported vector db type.")
+
+
+class DatasetRetrievalSettingMockApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, vector_type):
+
+        if vector_type == 'milvus':
+            return {
+                'retrieval_method': [
+                    'semantic_search'
+                ]
+            }
+        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+            return {
+                'retrieval_method': [
+                    'semantic_search', 'full_text_search', 'hybrid_search'
+                ]
+            }
+        else:
+            raise ValueError("Unsupported vector db type.")
+
+
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
@@ -445,3 +491,5 @@ api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing
 api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
 api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
+api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
+api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')

+ 4 - 0
api/controllers/console/datasets/datasets_document.py

@@ -221,6 +221,8 @@ class DatasetDocumentListApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
 
         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -263,6 +265,8 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         if args['indexing_technique'] == 'high_quality':
             try:

+ 5 - 6
api/controllers/console/datasets/hit_testing.py

@@ -42,19 +42,18 @@ class HitTestingApi(Resource):
 
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
         args = parser.parse_args()
 
-        query = args['query']
-
-        if not query or len(query) > 250:
-            raise ValueError('Query is required and cannot exceed 250 characters')
+        HitTestingService.hit_testing_args_check(args)
 
         try:
             response = HitTestingService.retrieve(
                 dataset=dataset,
-                query=query,
+                query=args['query'],
                 account=current_user,
-                limit=10,
+                retrieval_model=args['retrieval_model'],
+                limit=10
             )
 
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}

+ 10 - 11
api/controllers/console/workspace/models.py

@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
@@ -71,19 +71,18 @@ class DefaultModelApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
-        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
         provider_service = ProviderService()
-        provider_service.update_default_model_of_model_type(
-            tenant_id=current_user.current_tenant_id,
-            model_type=args['model_type'],
-            provider_name=args['provider_name'],
-            model_name=args['model_name']
-        )
+        model_settings = args['model_settings']
+        for model_setting in model_settings:
+            provider_service.update_default_model_of_model_type(
+                tenant_id=current_user.current_tenant_id,
+                model_type=model_setting['model_type'],
+                provider_name=model_setting['provider_name'],
+                model_name=model_setting['model_name']
+            )
 
         return {'result': 'success'}
 

+ 4 - 0
api/controllers/service_api/dataset/document.py

@@ -36,6 +36,8 @@ class DocumentAddByTextApi(DatasetApiResource):
                             location='json')
         parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -95,6 +97,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)

+ 0 - 2
api/core/agent/agent/multi_dataset_router_agent.py

@@ -14,7 +14,6 @@ from pydantic import root_validator
 from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
 from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@@ -60,7 +59,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             return AgentFinish(return_values={"output": ''}, log='')
         elif len(self.tools) == 1:
             tool = next(iter(self.tools))
-            tool = cast(DatasetRetrieverTool, tool)
             rst = tool.run(tool_input={'query': kwargs['input']})
             # output = ''
             # rst_json = json.loads(rst)

+ 158 - 0
api/core/agent/agent/output_parser/retirver_dataset_agent.py

@@ -0,0 +1,158 @@
+import json
+from typing import Tuple, List, Any, Union, Sequence, Optional, cast
+
+from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
+from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
+from langchain.prompts.chat import BaseMessagePromptTemplate
+from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.tools import BaseTool
+from pydantic import root_validator
+
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.third_party.langchain.llms.fake import FakeLLM
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+
+
+class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
+    """
+    An Multi Dataset Retrieve Agent driven by Router.
+    """
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @root_validator
+    def validate_llm(cls, values: dict) -> dict:
+        return values
+
+    def should_use_agent(self, query: str):
+        """
+        return should use agent
+
+        :param query:
+        :return:
+        """
+        return True
+
+    def plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        if len(self.tools) == 0:
+            return AgentFinish(return_values={"output": ''}, log='')
+        elif len(self.tools) == 1:
+            tool = next(iter(self.tools))
+            tool = cast(DatasetRetrieverTool, tool)
+            rst = tool.run(tool_input={'query': kwargs['input']})
+            # output = ''
+            # rst_json = json.loads(rst)
+            # for item in rst_json:
+            #     output += f'{item["content"]}\n'
+            return AgentFinish(return_values={"output": rst}, log=rst)
+
+        if intermediate_steps:
+            _, observation = intermediate_steps[-1]
+            return AgentFinish(return_values={"output": observation}, log=observation)
+
+        try:
+            agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
+            if isinstance(agent_decision, AgentAction):
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            else:
+                agent_decision.return_values['output'] = ''
+            return agent_decision
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
+
+    def real_plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decided what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date, along with observations
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
+        selected_inputs = {
+            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
+        }
+        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
+        prompt = self.prompt.format_prompt(**full_inputs)
+        messages = prompt.to_messages()
+        prompt_messages = to_prompt_messages(messages)
+        result = self.model_instance.run(
+            messages=prompt_messages,
+            functions=self.functions,
+        )
+
+        ai_message = AIMessage(
+            content=result.content,
+            additional_kwargs={
+                'function_call': result.function_call
+            }
+        )
+
+        agent_decision = _parse_ai_message(ai_message)
+        return agent_decision
+
+    async def aplan(
+            self,
+            intermediate_steps: List[Tuple[AgentAction, str]],
+            callbacks: Callbacks = None,
+            **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        raise NotImplementedError()
+
+    @classmethod
+    def from_llm_and_tools(
+            cls,
+            model_instance: BaseLLM,
+            tools: Sequence[BaseTool],
+            callback_manager: Optional[BaseCallbackManager] = None,
+            extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+            system_message: Optional[SystemMessage] = SystemMessage(
+                content="You are a helpful AI assistant."
+            ),
+            **kwargs: Any,
+    ) -> BaseSingleActionAgent:
+        prompt = cls.create_prompt(
+            extra_prompt_messages=extra_prompt_messages,
+            system_message=system_message,
+        )
+        return cls(
+            model_instance=model_instance,
+            llm=FakeLLM(response=''),
+            prompt=prompt,
+            tools=tools,
+            callback_manager=callback_manager,
+            **kwargs,
+        )

+ 0 - 1
api/core/agent/agent/structed_multi_dataset_router_agent.py

@@ -89,7 +89,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
             return AgentFinish(return_values={"output": ''}, log='')
         elif len(self.dataset_tools) == 1:
             tool = next(iter(self.dataset_tools))
-            tool = cast(DatasetRetrieverTool, tool)
             rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)
 

+ 3 - 2
api/core/agent/agent_executor.py

@@ -18,6 +18,7 @@ from langchain.agents import AgentExecutor as LCAgentExecutor
 from core.helper import moderation
 from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
+from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
@@ -78,7 +79,7 @@ class AgentExecutor:
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
                 tools=self.configuration.tools,
@@ -86,7 +87,7 @@ class AgentExecutor:
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
-            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)]
             agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
                 model_instance=self.configuration.model_instance,
                 tools=self.configuration.tools,

+ 1 - 3
api/core/callback_handler/index_tool_callback_handler.py

@@ -10,8 +10,7 @@ from models.dataset import DocumentSegment
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
-    def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
-        self.dataset_id = dataset_id
+    def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
         self.conversation_message_task = conversation_message_task
 
     def on_tool_end(self, documents: List[Document]) -> None:
@@ -21,7 +20,6 @@ class DatasetIndexToolCallbackHandler:
 
             # add hit count to document segment
             db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == self.dataset_id,
                 DocumentSegment.index_node_id == doc_id
             ).update(
                 {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},

+ 1 - 0
api/core/completion.py

@@ -127,6 +127,7 @@ class Completion:
                 memory=memory,
                 rest_tokens=rest_tokens_for_context_and_memory,
                 chain_callback=chain_callback,
+                tenant_id=app.tenant_id,
                 retriever_from=retriever_from
             )
 

+ 28 - 18
api/core/data_loader/file_extractor.py

@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import List, Union, Optional
 
 import requests
-from langchain.document_loaders import TextLoader, Docx2txtLoader
+from langchain.document_loaders import TextLoader, Docx2txtLoader, UnstructuredFileLoader, UnstructuredAPIFileLoader
 from langchain.schema import Document
 
 from core.data_loader.loader.csv_loader import CSVLoader
@@ -20,13 +20,13 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
 
 class FileExtractor:
     @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
+    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document] | str]:
         with tempfile.TemporaryDirectory() as temp_dir:
             suffix = Path(upload_file.key).suffix
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
             storage.download(upload_file.key, file_path)
 
-            return cls.load_from_file(file_path, return_text, upload_file)
+            return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
 
     @classmethod
     def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document] | str]:
@@ -44,24 +44,34 @@ class FileExtractor:
 
     @classmethod
     def load_from_file(cls, file_path: str, return_text: bool = False,
-                       upload_file: Optional[UploadFile] = None) -> Union[List[Document] | str]:
+                       upload_file: Optional[UploadFile] = None,
+                       is_automatic: bool = False) -> Union[List[Document] | str]:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()
-        if file_extension == '.xlsx':
-            loader = ExcelLoader(file_path)
-        elif file_extension == '.pdf':
-            loader = PdfLoader(file_path, upload_file=upload_file)
-        elif file_extension in ['.md', '.markdown']:
-            loader = MarkdownLoader(file_path, autodetect_encoding=True)
-        elif file_extension in ['.htm', '.html']:
-            loader = HTMLLoader(file_path)
-        elif file_extension == '.docx':
-            loader = Docx2txtLoader(file_path)
-        elif file_extension == '.csv':
-            loader = CSVLoader(file_path, autodetect_encoding=True)
+        if is_automatic:
+            loader = UnstructuredFileLoader(
+                file_path, strategy="hi_res", mode="elements"
+            )
+            # loader = UnstructuredAPIFileLoader(
+            #     file_path=filenames[0],
+            #     api_key="FAKE_API_KEY",
+            # )
         else:
-            # txt
-            loader = TextLoader(file_path, autodetect_encoding=True)
+            if file_extension == '.xlsx':
+                loader = ExcelLoader(file_path)
+            elif file_extension == '.pdf':
+                loader = PdfLoader(file_path, upload_file=upload_file)
+            elif file_extension in ['.md', '.markdown']:
+                loader = MarkdownLoader(file_path, autodetect_encoding=True)
+            elif file_extension in ['.htm', '.html']:
+                loader = HTMLLoader(file_path)
+            elif file_extension == '.docx':
+                loader = Docx2txtLoader(file_path)
+            elif file_extension == '.csv':
+                loader = CSVLoader(file_path, autodetect_encoding=True)
+            else:
+                # txt
+                loader = TextLoader(file_path, autodetect_encoding=True)
 
         return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

+ 7 - 0
api/core/index/vector_index/base.py

@@ -40,6 +40,13 @@ class BaseVectorIndex(BaseIndex):
     def _get_vector_store_class(self) -> type:
         raise NotImplementedError
 
+    @abstractmethod
+    def search_by_full_text_index(
+            self, query: str,
+            **kwargs: Any
+    ) -> List[Document]:
+        raise NotImplementedError
+
     def search(
             self, query: str,
             **kwargs: Any

+ 9 - 7
api/core/index/vector_index/milvus_vector_index.py

@@ -1,16 +1,14 @@
-from typing import Optional, cast
+from typing import cast, Any, List
 
 from langchain.embeddings.base import Embeddings
-from langchain.schema import Document, BaseRetriever
-from langchain.vectorstores import VectorStore, milvus
+from langchain.schema import Document
+from langchain.vectorstores import VectorStore
 from pydantic import BaseModel, root_validator
 
 from core.index.base import BaseIndex
 from core.index.vector_index.base import BaseVectorIndex
 from core.vector_store.milvus_vector_store import MilvusVectorStore
-from core.vector_store.weaviate_vector_store import WeaviateVectorStore
-from extensions.ext_database import db
-from models.dataset import Dataset, DatasetCollectionBinding
+from models.dataset import Dataset
 
 
 class MilvusConfig(BaseModel):
@@ -74,7 +72,7 @@ class MilvusVectorIndex(BaseVectorIndex):
         index_params = {
             'metric_type': 'IP',
             'index_type': "HNSW",
-            'params':  {"M": 8, "efConstruction": 64}
+            'params': {"M": 8, "efConstruction": 64}
         }
         self._vector_store = MilvusVectorStore.from_documents(
             texts,
@@ -152,3 +150,7 @@ class MilvusVectorIndex(BaseVectorIndex):
                 ),
             ],
         ))
+
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+        # milvus/zilliz doesn't support bm25 search
+        return []

+ 18 - 0
api/core/index/vector_index/qdrant_vector_index.py

@@ -191,3 +191,21 @@ class QdrantVectorIndex(BaseVectorIndex):
                 return True
 
         return False
+
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        return vector_store.similarity_search_by_bm25(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+                models.FieldCondition(
+                    key="page_content",
+                    match=models.MatchText(text=query),
+                )
+            ],
+        ), kwargs.get('top_k', 2))

+ 8 - 1
api/core/index/vector_index/weaviate_vector_index.py

@@ -1,4 +1,4 @@
-from typing import Optional, cast
+from typing import Optional, cast, Any, List
 
 import requests
 import weaviate
@@ -26,6 +26,7 @@ class WeaviateConfig(BaseModel):
 
 
 class WeaviateVectorIndex(BaseVectorIndex):
+
     def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
         super().__init__(dataset, embeddings)
         self._client = self._init_client(config)
@@ -148,3 +149,9 @@ class WeaviateVectorIndex(BaseVectorIndex):
                 return True
 
         return False
+
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
+

+ 5 - 5
api/core/indexing_runner.py

@@ -49,14 +49,14 @@ class IndexingRunner:
                 if not dataset:
                     raise ValueError("no dataset found")
 
-                # load file
-                text_docs = self._load_data(dataset_document)
-
                 # get the process rule
                 processing_rule = db.session.query(DatasetProcessRule). \
                     filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                     first()
 
+                # load file
+                text_docs = self._load_data(dataset_document)
+
                 # get splitter
                 splitter = self._get_splitter(processing_rule)
 
@@ -380,7 +380,7 @@ class IndexingRunner:
             "preview": preview_texts
         }
 
-    def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
         # load file
         if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
@@ -396,7 +396,7 @@ class IndexingRunner:
                 one_or_none()
 
             if file_detail:
-                text_docs = FileExtractor.load(file_detail)
+                text_docs = FileExtractor.load(file_detail, is_automatic=False)
         elif dataset_document.data_source_type == 'notion_import':
             loader = NotionLoader.from_document(dataset_document)
             text_docs = loader.load()

+ 39 - 0
api/core/model_providers/model_factory.py

@@ -9,6 +9,7 @@ from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.moderation.base import BaseModeration
+from core.model_providers.models.reranking.base import BaseReranking
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel
@@ -140,6 +141,44 @@ class ModelFactory:
             name=model_name
         )
 
+
+    @classmethod
+    def get_reranking_model(cls,
+                            tenant_id: str,
+                            model_provider_name: Optional[str] = None,
+                            model_name: Optional[str] = None) -> Optional[BaseReranking]:
+        """
+        get reranking model.
+
+        :param tenant_id: a string representing the ID of the tenant.
+        :param model_provider_name:
+        :param model_name:
+        :return:
+        """
+        if model_provider_name is None and model_name is None:
+            default_model = cls.get_default_model(tenant_id, ModelType.RERANKING)
+
+            if not default_model:
+                raise LLMBadRequestError(f"Default model is not available. "
+                                         f"Please configure a Default Reranking Model "
+                                         f"in the Settings -> Model Provider.")
+
+            model_provider_name = default_model.provider_name
+            model_name = default_model.model_name
+
+        # get model provider
+        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
+        if not model_provider:
+            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
+
+        # init reranking model
+        model_class = model_provider.get_model_class(model_type=ModelType.RERANKING)
+        return model_class(
+            model_provider=model_provider,
+            name=model_name
+        )
+
     @classmethod
     def get_speech2text_model(cls,
                               tenant_id: str,

+ 3 - 0
api/core/model_providers/model_provider_factory.py

@@ -72,6 +72,9 @@ class ModelProviderFactory:
         elif provider_name == 'localai':
             from core.model_providers.providers.localai_provider import LocalAIProvider
             return LocalAIProvider
+        elif provider_name == 'cohere':
+            from core.model_providers.providers.cohere_provider import CohereProvider
+            return CohereProvider
         else:
             raise NotImplementedError
 

+ 1 - 1
api/core/model_providers/models/entity/model_params.py

@@ -17,7 +17,7 @@ class ModelType(enum.Enum):
     IMAGE = 'image'
     VIDEO = 'video'
     MODERATION = 'moderation'
-
+    RERANKING = 'reranking'
     @staticmethod
     def value_of(value):
         for member in ModelType:

+ 0 - 0
api/core/model_providers/models/reranking/__init__.py


+ 36 - 0
api/core/model_providers/models/reranking/base.py

@@ -0,0 +1,36 @@
+from abc import abstractmethod
+from typing import Any, Optional, List
+from langchain.schema import Document
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseReranking(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.RERANKING
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    @property
+    def base_model_name(self) -> str:
+        """
+        get base model name
+        
+        :return: str
+        """
+        return self.name
+
+    @abstractmethod
+    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError

+ 73 - 0
api/core/model_providers/models/reranking/cohere_reranking.py

@@ -0,0 +1,73 @@
+import logging
+from typing import Optional, List
+
+import cohere
+import openai
+from langchain.schema import Document
+
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.models.reranking.base import BaseReranking
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class CohereReranking(BaseReranking):
+
+    def __init__(self, model_provider: BaseModelProvider, name: str):
+        self.credentials = model_provider.get_model_credentials(
+            model_name=name,
+            model_type=self.type
+        )
+
+        client = cohere.Client(self.credentials.get('api_key'))
+
+        super().__init__(model_provider, client, name)
+
+    def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
+        docs = []
+        doc_id = []
+        for document in documents:
+            if document.metadata['doc_id'] not in doc_id:
+                doc_id.append(document.metadata['doc_id'])
+                docs.append(document.page_content)
+        results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
+        rerank_documents = []
+
+        for idx, result in enumerate(results):
+            # format document
+            rerank_document = Document(
+                page_content=result.document['text'],
+                metadata={
+                    "doc_id": documents[result.index].metadata['doc_id'],
+                    "doc_hash": documents[result.index].metadata['doc_hash'],
+                    "document_id": documents[result.index].metadata['document_id'],
+                    "dataset_id": documents[result.index].metadata['dataset_id'],
+                    'score': result.relevance_score
+                }
+            )
+            # score threshold check
+            if score_threshold is not None:
+                if result.relevance_score >= score_threshold:
+                    rerank_documents.append(rerank_document)
+            else:
+                rerank_documents.append(rerank_document)
+        return rerank_documents
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to OpenAI API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to OpenAI API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("OpenAI service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
+        else:
+            return ex

+ 152 - 0
api/core/model_providers/providers/cohere_provider.py

@@ -0,0 +1,152 @@
+import json
+from json import JSONDecodeError
+from typing import Type
+
+from langchain.schema import HumanMessage
+
+from core.helper import encrypter
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
+from core.model_providers.models.reranking.cohere_reranking import CohereReranking
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from models.provider import ProviderType
+
+
+class CohereProvider(BaseModelProvider):
+
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'cohere'
+    
+    def _get_text_generation_model_mode(self, model_name) -> str:
+        return ModelMode.CHAT.value
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        if model_type == ModelType.RERANKING:
+            return [
+                {
+                    'id': 'rerank-english-v2.0',
+                    'name': 'rerank-english-v2.0'
+                },
+                {
+                    'id': 'rerank-multilingual-v2.0',
+                    'name': 'rerank-multilingual-v2.0'
+                }
+            ]
+        else:
+            return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.RERANKING:
+            model_class = CohereReranking
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        return ModelKwargsRules(
+            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
+            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
+            presence_penalty=KwargRule[float](enabled=False),
+            frequency_penalty=KwargRule[float](enabled=False),
+            max_tokens=KwargRule[int](enabled=False),
+        )
+
+    @classmethod
+    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
+        """
+        Validates the given credentials.
+        """
+        if 'api_key' not in credentials:
+            raise CredentialsValidateFailedError('Cohere api_key must be provided.')
+
+        try:
+            credential_kwargs = {
+                'api_key': credentials['api_key'],
+            }
+            # todo validate
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @classmethod
+    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
+        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
+        return credentials
+
+    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
+        if self.provider.provider_type == ProviderType.CUSTOM.value:
+            try:
+                credentials = json.loads(self.provider.encrypted_config)
+            except JSONDecodeError:
+                credentials = {
+                    'api_key': None,
+                }
+
+            if credentials['api_key']:
+                credentials['api_key'] = encrypter.decrypt_token(
+                    self.provider.tenant_id,
+                    credentials['api_key']
+                )
+
+                if obfuscated:
+                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
+
+            return credentials
+        else:
+            return {}
+
+    def should_deduct_quota(self):
+        return True
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        return
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        return {}
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        return self.get_provider_credentials(obfuscated)

+ 2 - 1
api/core/model_providers/rules/_providers.json

@@ -13,5 +13,6 @@
   "huggingface_hub",
   "xinference",
   "openllm",
-  "localai"
+  "localai",
+  "cohere"
 ]

+ 7 - 0
api/core/model_providers/rules/cohere.json

@@ -0,0 +1,7 @@
+{
+    "support_provider_types": [
+        "custom"
+    ],
+    "system_config": null,
+    "model_flexibility": "fixed"
+}

+ 85 - 42
api/core/orchestrator_rule_parser.py

@@ -1,11 +1,17 @@
-from typing import Optional
+import json
+import threading
+from typing import Optional, List
 
+from flask import Flask
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool, Tool, WikipediaQueryRun
 from pydantic import BaseModel, Field
 
+from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
+from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -17,6 +23,7 @@ from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.current_datetime_tool import DatetimeTool
+from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tool.provider.serpapi_provider import SerpAPIToolProvider
 from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
@@ -25,6 +32,16 @@ from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
 
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enable': False
+}
 
 class OrchestratorRuleParser:
     """Parse the orchestrator rule to entities."""
@@ -34,7 +51,7 @@ class OrchestratorRuleParser:
         self.app_model_config = app_model_config
 
     def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
+                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
                           retriever_from: str = 'dev') -> Optional[AgentExecutor]:
         if not self.app_model_config.agent_mode_dict:
             return None
@@ -101,7 +118,8 @@ class OrchestratorRuleParser:
                 rest_tokens=rest_tokens,
                 return_resource=return_resource,
                 retriever_from=retriever_from,
-                dataset_configs=dataset_configs
+                dataset_configs=dataset_configs,
+                tenant_id=tenant_id
             )
 
             if len(tools) == 0:
@@ -123,7 +141,7 @@ class OrchestratorRuleParser:
 
         return chain
 
-    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
+    def to_tools(self, tool_configs: list, callbacks: Callbacks = None,  **kwargs) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
@@ -132,6 +150,7 @@ class OrchestratorRuleParser:
         :return:
         """
         tools = []
+        dataset_tools = []
         for tool_config in tool_configs:
             tool_type = list(tool_config.keys())[0]
             tool_val = list(tool_config.values())[0]
@@ -140,7 +159,7 @@ class OrchestratorRuleParser:
 
             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
+                dataset_tools.append(tool_config)
             elif tool_type == "web_reader":
                 tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
             elif tool_type == "google_search":
@@ -156,57 +175,81 @@ class OrchestratorRuleParser:
                 else:
                     tool.callbacks = callbacks
                 tools.append(tool)
-
+        # format dataset tool
+        if len(dataset_tools) > 0:
+            dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
+            if dataset_retriever_tools:
+                tools.extend(dataset_retriever_tools)
         return tools
 
-    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                  dataset_configs: dict, rest_tokens: int,
+    def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
                                   return_resource: bool = False, retriever_from: str = 'dev',
                                   **kwargs) \
-            -> Optional[BaseTool]:
+            -> Optional[List[BaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
-        :param rest_tokens:
-        :param tool_config:
-        :param dataset_configs:
+        :param tool_configs:
         :param conversation_message_task:
         :param return_resource:
         :param retriever_from:
         :return:
         """
-        # get dataset from dataset id
-        dataset = db.session.query(Dataset).filter(
-            Dataset.tenant_id == self.tenant_id,
-            Dataset.id == tool_config.get("id")
-        ).first()
-
-        if not dataset:
-            return None
-
-        if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
-            return None
-
-        top_k = dataset_configs.get("top_k", 2)
-
-        # dynamically adjust top_k when the remaining token number is not enough to support top_k
-        top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
+        dataset_configs = kwargs['dataset_configs']
+        retrieval_model = dataset_configs.get('retrieval_model', 'single')
+        tools = []
+        dataset_ids = []
+        tenant_id = None
+        for tool_config in tool_configs:
+            # get dataset from dataset id
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == tool_config.get('dataset').get("id")
+            ).first()
 
-        score_threshold = None
-        score_threshold_config = dataset_configs.get("score_threshold")
-        if score_threshold_config and score_threshold_config.get("enable"):
-            score_threshold = score_threshold_config.get("value")
+            if not dataset:
+                return None
 
-        tool = DatasetRetrieverTool.from_dataset(
-            dataset=dataset,
-            top_k=top_k,
-            score_threshold=score_threshold,
-            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
-            conversation_message_task=conversation_message_task,
-            return_resource=return_resource,
-            retriever_from=retriever_from
-        )
+            if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
+                return None
+            dataset_ids.append(dataset.id)
+            if retrieval_model == 'single':
+                retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+                top_k = retrieval_model['top_k']
+
+                # dynamically adjust top_k when the remaining token number is not enough to support top_k
+                # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
+
+                score_threshold = None
+                score_threshold_enable = retrieval_model.get("score_threshold_enable")
+                if score_threshold_enable:
+                    score_threshold = retrieval_model.get("score_threshold")
+
+                tool = DatasetRetrieverTool.from_dataset(
+                    dataset=dataset,
+                    top_k=top_k,
+                    score_threshold=score_threshold,
+                    callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
+                    conversation_message_task=conversation_message_task,
+                    return_resource=return_resource,
+                    retriever_from=retriever_from
+                )
+                tools.append(tool)
+        if retrieval_model == 'multiple':
+            tool = DatasetMultiRetrieverTool.from_dataset(
+                dataset_ids=dataset_ids,
+                tenant_id=kwargs['tenant_id'],
+                top_k=dataset_configs.get('top_k', 2),
+                score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
+                callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
+                conversation_message_task=conversation_message_task,
+                return_resource=return_resource,
+                retriever_from=retriever_from,
+                reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'),
+                reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name')
+            )
+            tools.append(tool)
 
-        return tool
+        return tools
 
     def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
         """

+ 227 - 0
api/core/tool/dataset_multi_retriever_tool.py

@@ -0,0 +1,227 @@
+import json
+import threading
+from typing import Type, Optional, List
+
+from flask import current_app, Flask
+from langchain.tools import BaseTool
+from pydantic import Field, BaseModel
+
+from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.conversation_message_task import ConversationMessageTask
+from core.embedding.cached_embedding import CacheEmbedding
+from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
+from extensions.ext_database import db
+from models.dataset import Dataset, DocumentSegment, Document
+from services.retrieval_service import RetrievalService
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enable': False
+}
+
+
+class DatasetMultiRetrieverToolInput(BaseModel):
+    query: str = Field(..., description="dataset multi retriever and rerank")
+
+
+class DatasetMultiRetrieverTool(BaseTool):
+    """Tool for querying multi dataset."""
+    name: str = "dataset-"
+    args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
+    description: str = "dataset multi retriever and rerank. "
+    tenant_id: str
+    dataset_ids: List[str]
+    top_k: int = 2
+    score_threshold: Optional[float] = None
+    reranking_provider_name: str
+    reranking_model_name: str
+    conversation_message_task: ConversationMessageTask
+    return_resource: bool
+    retriever_from: str
+
+    @classmethod
+    def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
+        return cls(
+            name=f'dataset-{tenant_id}',
+            tenant_id=tenant_id,
+            dataset_ids=dataset_ids,
+            **kwargs
+        )
+
+    def _run(self, query: str) -> str:
+        threads = []
+        all_documents = []
+        for dataset_id in self.dataset_ids:
+            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset_id': dataset_id,
+                'query': query,
+                'all_documents': all_documents
+            })
+            threads.append(retrieval_thread)
+            retrieval_thread.start()
+        for thread in threads:
+            thread.join()
+        # do rerank for searched documents
+        rerank = ModelFactory.get_reranking_model(
+            tenant_id=self.tenant_id,
+            model_provider_name=self.reranking_provider_name,
+            model_name=self.reranking_model_name
+        )
+        all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k)
+
+        hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
+        hit_callback.on_tool_end(all_documents)
+
+        document_context_list = []
+        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
+        segments = DocumentSegment.query.filter(
+            DocumentSegment.completed_at.isnot(None),
+            DocumentSegment.status == 'completed',
+            DocumentSegment.enabled == True,
+            DocumentSegment.index_node_id.in_(index_node_ids)
+        ).all()
+
+        if segments:
+            index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
+            sorted_segments = sorted(segments,
+                                     key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
+                                                                                       float('inf')))
+            for segment in sorted_segments:
+                if segment.answer:
+                    document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
+                else:
+                    document_context_list.append(segment.content)
+            if self.return_resource:
+                context_list = []
+                resource_number = 1
+                for segment in sorted_segments:
+                    dataset = Dataset.query.filter_by(
+                        id=segment.dataset_id
+                    ).first()
+                    document = Document.query.filter(Document.id == segment.document_id,
+                                                     Document.enabled == True,
+                                                     Document.archived == False,
+                                                     ).first()
+                    if dataset and document:
+                        source = {
+                            'position': resource_number,
+                            'dataset_id': dataset.id,
+                            'dataset_name': dataset.name,
+                            'document_id': document.id,
+                            'document_name': document.name,
+                            'data_source_type': document.data_source_type,
+                            'segment_id': segment.id,
+                            'retriever_from': self.retriever_from
+                        }
+                        if self.retriever_from == 'dev':
+                            source['hit_count'] = segment.hit_count
+                            source['word_count'] = segment.word_count
+                            source['segment_position'] = segment.position
+                            source['index_node_hash'] = segment.index_node_hash
+                        if segment.answer:
+                            source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                        else:
+                            source['content'] = segment.content
+                        context_list.append(source)
+                    resource_number += 1
+                hit_callback.return_retriever_resource_info(context_list)
+
+            return str("\n".join(document_context_list))
+
+    async def _arun(self, tool_input: str) -> str:
+        raise NotImplementedError()
+
+    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
+        with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.tenant_id == self.tenant_id,
+                Dataset.id == dataset_id
+            ).first()
+
+            if not dataset:
+                return []
+            # get retrieval model , if the model is not setting , using default
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+            if dataset.indexing_technique == "economy":
+                # use keyword table query
+                kw_table_index = KeywordTableIndex(
+                    dataset=dataset,
+                    config=KeywordTableConfig(
+                        max_keywords_per_chunk=5
+                    )
+                )
+
+                documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
+                if documents:
+                    all_documents.extend(documents)
+            else:
+
+                try:
+                    embedding_model = ModelFactory.get_embedding_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except LLMBadRequestError:
+                    return []
+                except ProviderTokenNotInitError:
+                    return []
+
+                embeddings = CacheEmbedding(embedding_model)
+
+                documents = []
+                threads = []
+                if self.top_k > 0:
+                    # retrieval_model source with semantic
+                    if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
+                        'search_method'] == 'hybrid_search':
+                        embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
+                            'flask_app': current_app._get_current_object(),
+                            'dataset': dataset,
+                            'query': query,
+                            'top_k': self.top_k,
+                            'score_threshold': self.score_threshold,
+                            'reranking_model': None,
+                            'all_documents': documents,
+                            'search_method': 'hybrid_search',
+                            'embeddings': embeddings
+                        })
+                        threads.append(embedding_thread)
+                        embedding_thread.start()
+
+                    # retrieval_model source with full text
+                    if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
+                        'search_method'] == 'hybrid_search':
+                        full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
+                                                                  kwargs={
+                                                                      'flask_app': current_app._get_current_object(),
+                                                                      'dataset': dataset,
+                                                                      'query': query,
+                                                                      'search_method': 'hybrid_search',
+                                                                      'embeddings': embeddings,
+                                                                      'score_threshold': retrieval_model[
+                                                                          'score_threshold'] if retrieval_model[
+                                                                          'score_threshold_enable'] else None,
+                                                                      'top_k': self.top_k,
+                                                                      'reranking_model': retrieval_model[
+                                                                          'reranking_model'] if retrieval_model[
+                                                                          'reranking_enable'] else None,
+                                                                      'all_documents': documents
+                                                                  })
+                        threads.append(full_text_index_thread)
+                        full_text_index_thread.start()
+
+                    for thread in threads:
+                        thread.join()
+
+                    all_documents.extend(documents)

+ 68 - 19
api/core/tool/dataset_retriever_tool.py

@@ -1,5 +1,6 @@
 import json
-from typing import Type, Optional
+import threading
+from typing import Type, Optional, List
 
 from flask import current_app
 from langchain.tools import BaseTool
@@ -14,6 +15,18 @@ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitE
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
 from models.dataset import Dataset, DocumentSegment, Document
+from services.retrieval_service import RetrievalService
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enable': False
+}
 
 
 class DatasetRetrieverToolInput(BaseModel):
@@ -56,7 +69,9 @@ class DatasetRetrieverTool(BaseTool):
         ).first()
 
         if not dataset:
-            return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
+            return ''
+        # get retrieval model , if the model is not setting , using default
+        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
         if dataset.indexing_technique == "economy":
             # use keyword table query
@@ -83,28 +98,62 @@ class DatasetRetrieverTool(BaseTool):
                 return ''
 
             embeddings = CacheEmbedding(embedding_model)
-            vector_index = VectorIndex(
-                dataset=dataset,
-                config=current_app.config,
-                embeddings=embeddings
-            )
 
+            documents = []
+            threads = []
             if self.top_k > 0:
-                documents = vector_index.search(
-                    query,
-                    search_type='similarity_score_threshold',
-                    search_kwargs={
-                        'k': self.top_k,
-                        'score_threshold': self.score_threshold,
-                        'filter': {
-                            'group_id': [dataset.id]
-                        }
-                    }
-                )
+                # retrieval source with semantic
+                if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
+                    embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
+                        'flask_app': current_app._get_current_object(),
+                        'dataset': dataset,
+                        'query': query,
+                        'top_k': self.top_k,
+                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
+                            'score_threshold_enable'] else None,
+                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
+                            'reranking_enable'] else None,
+                        'all_documents': documents,
+                        'search_method': retrieval_model['search_method'],
+                        'embeddings': embeddings
+                    })
+                    threads.append(embedding_thread)
+                    embedding_thread.start()
+
+                # retrieval_model source with full text
+                if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
+                    full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
+                        'flask_app': current_app._get_current_object(),
+                        'dataset': dataset,
+                        'query': query,
+                        'search_method': retrieval_model['search_method'],
+                        'embeddings': embeddings,
+                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
+                            'score_threshold_enable'] else None,
+                        'top_k': self.top_k,
+                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
+                            'reranking_enable'] else None,
+                        'all_documents': documents
+                    })
+                    threads.append(full_text_index_thread)
+                    full_text_index_thread.start()
+
+                for thread in threads:
+                    thread.join()
+                # hybrid search: rerank after all documents have been searched
+                if retrieval_model['search_method'] == 'hybrid_search':
+                    hybrid_rerank = ModelFactory.get_reranking_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
+                        model_name=retrieval_model['reranking_model']['reranking_model_name']
+                    )
+                    documents = hybrid_rerank.rerank(query, documents,
+                                                     retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                                                     self.top_k)
             else:
                 documents = []
 
-            hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
+            hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
             hit_callback.on_tool_end(documents)
             document_score_list = {}
             if dataset.indexing_technique != "economy":

+ 1 - 1
api/core/vector_store/milvus_vector_store.py

@@ -1,4 +1,4 @@
-from core.index.vector_index.milvus import Milvus
+from core.vector_store.vector.milvus import Milvus
 
 
 class MilvusVectorStore(Milvus):

+ 2 - 1
api/core/vector_store/qdrant_vector_store.py

@@ -4,7 +4,7 @@ from langchain.schema import Document
 from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
 from qdrant_client.local.qdrant_local import QdrantLocal
 
-from core.index.vector_index.qdrant import Qdrant
+from core.vector_store.vector.qdrant import Qdrant
 
 
 class QdrantVectorStore(Qdrant):
@@ -73,3 +73,4 @@ class QdrantVectorStore(Qdrant):
         if isinstance(self.client, QdrantLocal):
             self.client = cast(QdrantLocal, self.client)
             self.client._load()
+

+ 0 - 0
api/core/index/vector_index/milvus.py → api/core/vector_store/vector/milvus.py


+ 47 - 3
api/core/index/vector_index/qdrant.py → api/core/vector_store/vector/qdrant.py

@@ -28,7 +28,7 @@ from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
-from qdrant_client.http.models import PayloadSchemaType
+from qdrant_client.http.models import PayloadSchemaType, FilterSelector, TextIndexParams, TokenizerType, TextIndexType
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -189,14 +189,25 @@ class Qdrant(VectorStore):
             texts, metadatas, ids, batch_size
         ):
             self.client.upsert(
-                collection_name=self.collection_name, points=points, **kwargs
+                collection_name=self.collection_name, points=points
             )
             added_ids.extend(batch_ids)
         # if is new collection, create payload index on group_id
         if self.is_new_collection:
+            # create payload index
             self.client.create_payload_index(self.collection_name, self.group_payload_key,
                                              field_schema=PayloadSchemaType.KEYWORD,
                                              field_type=PayloadSchemaType.KEYWORD)
+            # creat full text index
+            text_index_params = TextIndexParams(
+                type=TextIndexType.TEXT,
+                tokenizer=TokenizerType.MULTILINGUAL,
+                min_token_len=2,
+                max_token_len=20,
+                lowercase=True
+            )
+            self.client.create_payload_index(self.collection_name, self.content_payload_key,
+                                             field_schema=text_index_params)
         return added_ids
 
     @sync_call_fallback
@@ -600,7 +611,7 @@ class Qdrant(VectorStore):
             limit=k,
             offset=offset,
             with_payload=True,
-            with_vectors=True,  # Langchain does not expect vectors to be returned
+            with_vectors=True,
             score_threshold=score_threshold,
             consistency=consistency,
             **kwargs,
@@ -615,6 +626,39 @@ class Qdrant(VectorStore):
             for result in results
         ]
 
+    def similarity_search_by_bm25(
+        self,
+        filter: Optional[MetadataFilter] = None,
+        k: int = 4
+    ) -> List[Document]:
+        """Return docs most similar by bm25.
+
+        Args:
+            embedding: Embedding vector to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter: Filter by metadata. Defaults to None.
+            search_params: Additional search params
+        Returns:
+            List of documents most similar to the query text and distance for each.
+        """
+        response = self.client.scroll(
+            collection_name=self.collection_name,
+            scroll_filter=filter,
+            limit=k,
+            with_payload=True,
+            with_vectors=True
+
+        )
+        results = response[0]
+        documents = []
+        for result in results:
+            if result:
+                documents.append(self._document_from_scored_point(
+                        result, self.content_payload_key, self.metadata_payload_key
+                    ))
+
+        return documents
+
     @sync_call_fallback
     async def asimilarity_search_with_score_by_vector(
         self,

+ 505 - 0
api/core/vector_store/vector/weaviate.py

@@ -0,0 +1,505 @@
+"""Wrapper around weaviate vector database."""
+from __future__ import annotations
+
+import datetime
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
+from uuid import uuid4
+
+import numpy as np
+
+from langchain.docstore.document import Document
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.utils import maximal_marginal_relevance
+
+
+def _default_schema(index_name: str) -> Dict:
+    return {
+        "class": index_name,
+        "properties": [
+            {
+                "name": "text",
+                "dataType": ["text"],
+            }
+        ],
+    }
+
+
+def _create_weaviate_client(**kwargs: Any) -> Any:
+    client = kwargs.get("client")
+    if client is not None:
+        return client
+
+    weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
+
+    try:
+        # the weaviate api key param should not be mandatory
+        weaviate_api_key = get_from_dict_or_env(
+            kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
+        )
+    except ValueError:
+        weaviate_api_key = None
+
+    try:
+        import weaviate
+    except ImportError:
+        raise ValueError(
+            "Could not import weaviate python  package. "
+            "Please install it with `pip install weaviate-client`"
+        )
+
+    auth = (
+        weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
+        if weaviate_api_key is not None
+        else None
+    )
+    client = weaviate.Client(weaviate_url, auth_client_secret=auth)
+
+    return client
+
+
+def _default_score_normalizer(val: float) -> float:
+    return 1 - 1 / (1 + np.exp(val))
+
+
+def _json_serializable(value: Any) -> Any:
+    if isinstance(value, datetime.datetime):
+        return value.isoformat()
+    return value
+
+
+class Weaviate(VectorStore):
+    """Wrapper around Weaviate vector database.
+
+    To use, you should have the ``weaviate-client`` python package installed.
+
+    Example:
+        .. code-block:: python
+
+            import weaviate
+            from langchain.vectorstores import Weaviate
+            client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
+            weaviate = Weaviate(client, index_name, text_key)
+
+    """
+
+    def __init__(
+        self,
+        client: Any,
+        index_name: str,
+        text_key: str,
+        embedding: Optional[Embeddings] = None,
+        attributes: Optional[List[str]] = None,
+        relevance_score_fn: Optional[
+            Callable[[float], float]
+        ] = _default_score_normalizer,
+        by_text: bool = True,
+    ):
+        """Initialize with Weaviate client."""
+        try:
+            import weaviate
+        except ImportError:
+            raise ValueError(
+                "Could not import weaviate python package. "
+                "Please install it with `pip install weaviate-client`."
+            )
+        if not isinstance(client, weaviate.Client):
+            raise ValueError(
+                f"client should be an instance of weaviate.Client, got {type(client)}"
+            )
+        self._client = client
+        self._index_name = index_name
+        self._embedding = embedding
+        self._text_key = text_key
+        self._query_attrs = [self._text_key]
+        self.relevance_score_fn = relevance_score_fn
+        self._by_text = by_text
+        if attributes is not None:
+            self._query_attrs.extend(attributes)
+
+    @property
+    def embeddings(self) -> Optional[Embeddings]:
+        return self._embedding
+
+    def _select_relevance_score_fn(self) -> Callable[[float], float]:
+        return (
+            self.relevance_score_fn
+            if self.relevance_score_fn
+            else _default_score_normalizer
+        )
+
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Upload texts with metadata (properties) to Weaviate."""
+        from weaviate.util import get_valid_uuid
+
+        ids = []
+        embeddings: Optional[List[List[float]]] = None
+        if self._embedding:
+            if not isinstance(texts, list):
+                texts = list(texts)
+            embeddings = self._embedding.embed_documents(texts)
+
+        with self._client.batch as batch:
+            for i, text in enumerate(texts):
+                data_properties = {self._text_key: text}
+                if metadatas is not None:
+                    for key, val in metadatas[i].items():
+                        data_properties[key] = _json_serializable(val)
+
+                # Allow for ids (consistent w/ other methods)
+                # # Or uuids (backwards compatble w/ existing arg)
+                # If the UUID of one of the objects already exists
+                # then the existing object will be replaced by the new object.
+                _id = get_valid_uuid(uuid4())
+                if "uuids" in kwargs:
+                    _id = kwargs["uuids"][i]
+                elif "ids" in kwargs:
+                    _id = kwargs["ids"][i]
+
+                batch.add_data_object(
+                    data_object=data_properties,
+                    class_name=self._index_name,
+                    uuid=_id,
+                    vector=embeddings[i] if embeddings else None,
+                )
+                ids.append(_id)
+        return ids
+
+    def similarity_search(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        """Return docs most similar to query.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        if self._by_text:
+            return self.similarity_search_by_text(query, k, **kwargs)
+        else:
+            if self._embedding is None:
+                raise ValueError(
+                    "_embedding cannot be None for similarity_search when "
+                    "_by_text=False"
+                )
+            embedding = self._embedding.embed_query(query)
+            return self.similarity_search_by_vector(embedding, k, **kwargs)
+
+    def similarity_search_by_text(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        """Return docs most similar to query.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        content: Dict[str, Any] = {"concepts": [query]}
+        if kwargs.get("search_distance"):
+            content["certainty"] = kwargs.get("search_distance")
+        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        if kwargs.get("additional"):
+            query_obj = query_obj.with_additional(kwargs.get("additional"))
+        result = query_obj.with_near_text(content).with_limit(k).do()
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+        docs = []
+        for res in result["data"]["Get"][self._index_name]:
+            text = res.pop(self._text_key)
+            docs.append(Document(page_content=text, metadata=res))
+        return docs
+
+    def similarity_search_by_bm25(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        """Return docs using BM25F.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        content: Dict[str, Any] = {"concepts": [query]}
+        if kwargs.get("search_distance"):
+            content["certainty"] = kwargs.get("search_distance")
+        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        if kwargs.get("additional"):
+            query_obj = query_obj.with_additional(kwargs.get("additional"))
+        result = query_obj.with_bm25(query=content).with_limit(k).do()
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+        docs = []
+        for res in result["data"]["Get"][self._index_name]:
+            text = res.pop(self._text_key)
+            docs.append(Document(page_content=text, metadata=res))
+        return docs
+
+    def similarity_search_by_vector(
+        self, embedding: List[float], k: int = 4, **kwargs: Any
+    ) -> List[Document]:
+        """Look up similar documents by embedding vector in Weaviate."""
+        vector = {"vector": embedding}
+        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        if kwargs.get("additional"):
+            query_obj = query_obj.with_additional(kwargs.get("additional"))
+        result = query_obj.with_near_vector(vector).with_limit(k).do()
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+        docs = []
+        for res in result["data"]["Get"][self._index_name]:
+            text = res.pop(self._text_key)
+            docs.append(Document(page_content=text, metadata=res))
+        return docs
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5.
+
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        if self._embedding is not None:
+            embedding = self._embedding.embed_query(query)
+        else:
+            raise ValueError(
+                "max_marginal_relevance_search requires a suitable Embeddings object"
+            )
+
+        return self.max_marginal_relevance_search_by_vector(
+            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
+        )
+
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                        of diversity among the results with 0 corresponding
+                        to maximum diversity and 1 to minimum diversity.
+                        Defaults to 0.5.
+
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        vector = {"vector": embedding}
+        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+        if kwargs.get("where_filter"):
+            query_obj = query_obj.with_where(kwargs.get("where_filter"))
+        results = (
+            query_obj.with_additional("vector")
+            .with_near_vector(vector)
+            .with_limit(fetch_k)
+            .do()
+        )
+
+        payload = results["data"]["Get"][self._index_name]
+        embeddings = [result["_additional"]["vector"] for result in payload]
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+        )
+
+        docs = []
+        for idx in mmr_selected:
+            text = payload[idx].pop(self._text_key)
+            payload[idx].pop("_additional")
+            meta = payload[idx]
+            docs.append(Document(page_content=text, metadata=meta))
+        return docs
+
+    def similarity_search_with_score(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[Tuple[Document, float]]:
+        """
+        Return list of documents most similar to the query
+        text and cosine distance in float for each.
+        Lower score represents more similarity.
+        """
+        if self._embedding is None:
+            raise ValueError(
+                "_embedding cannot be None for similarity_search_with_score"
+            )
+        content: Dict[str, Any] = {"concepts": [query]}
+        if kwargs.get("search_distance"):
+            content["certainty"] = kwargs.get("search_distance")
+        query_obj = self._client.query.get(self._index_name, self._query_attrs)
+
+        embedded_query = self._embedding.embed_query(query)
+        if not self._by_text:
+            vector = {"vector": embedded_query}
+            result = (
+                query_obj.with_near_vector(vector)
+                .with_limit(k)
+                .with_additional("vector")
+                .do()
+            )
+        else:
+            result = (
+                query_obj.with_near_text(content)
+                .with_limit(k)
+                .with_additional("vector")
+                .do()
+            )
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        docs_and_scores = []
+        for res in result["data"]["Get"][self._index_name]:
+            text = res.pop(self._text_key)
+            score = np.dot(res["_additional"]["vector"], embedded_query)
+            docs_and_scores.append((Document(page_content=text, metadata=res), score))
+        return docs_and_scores
+
+    @classmethod
+    def from_texts(
+        cls: Type[Weaviate],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> Weaviate:
+        """Construct Weaviate wrapper from raw documents.
+
+        This is a user-friendly interface that:
+            1. Embeds documents.
+            2. Creates a new index for the embeddings in the Weaviate instance.
+            3. Adds the documents to the newly created Weaviate index.
+
+        This is intended to be a quick way to get started.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.vectorstores.weaviate import Weaviate
+                from langchain.embeddings import OpenAIEmbeddings
+                embeddings = OpenAIEmbeddings()
+                weaviate = Weaviate.from_texts(
+                    texts,
+                    embeddings,
+                    weaviate_url="http://localhost:8080"
+                )
+        """
+
+        client = _create_weaviate_client(**kwargs)
+
+        from weaviate.util import get_valid_uuid
+
+        index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
+        embeddings = embedding.embed_documents(texts) if embedding else None
+        text_key = "text"
+        schema = _default_schema(index_name)
+        attributes = list(metadatas[0].keys()) if metadatas else None
+
+        # check whether the index already exists
+        if not client.schema.contains(schema):
+            client.schema.create_class(schema)
+
+        with client.batch as batch:
+            for i, text in enumerate(texts):
+                data_properties = {
+                    text_key: text,
+                }
+                if metadatas is not None:
+                    for key in metadatas[i].keys():
+                        data_properties[key] = metadatas[i][key]
+
+                # If the UUID of one of the objects already exists
+                # then the existing objectwill be replaced by the new object.
+                if "uuids" in kwargs:
+                    _id = kwargs["uuids"][i]
+                else:
+                    _id = get_valid_uuid(uuid4())
+
+                # if an embedding strategy is not provided, we let
+                # weaviate create the embedding. Note that this will only
+                # work if weaviate has been installed with a vectorizer module
+                # like text2vec-contextionary for example
+                params = {
+                    "uuid": _id,
+                    "data_object": data_properties,
+                    "class_name": index_name,
+                }
+                if embeddings is not None:
+                    params["vector"] = embeddings[i]
+
+                batch.add_data_object(**params)
+
+            batch.flush()
+
+        relevance_score_fn = kwargs.get("relevance_score_fn")
+        by_text: bool = kwargs.get("by_text", False)
+
+        return cls(
+            client,
+            index_name,
+            text_key,
+            embedding=embedding,
+            attributes=attributes,
+            relevance_score_fn=relevance_score_fn,
+            by_text=by_text,
+        )
+
+    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
+        """Delete by vector IDs.
+
+        Args:
+            ids: List of ids to delete.
+        """
+
+        if ids is None:
+            raise ValueError("No ids provided to delete.")
+
+        # TODO: Check if this can be done in bulk
+        for id in ids:
+            self._client.data_object.delete(uuid=id)

+ 19 - 1
api/fields/dataset_fields.py

@@ -12,6 +12,21 @@ dataset_fields = {
     'created_at': TimestampField,
 }
 
+reranking_model_fields = {
+    'reranking_provider_name': fields.String,
+    'reranking_model_name': fields.String
+}
+
+dataset_retrieval_model_fields = {
+    'search_method': fields.String,
+    'reranking_enable': fields.Boolean,
+    'reranking_model': fields.Nested(reranking_model_fields),
+    'top_k': fields.Integer,
+    'score_threshold_enable': fields.Boolean,
+    'score_threshold': fields.Float
+}
+
+
 dataset_detail_fields = {
     'id': fields.String,
     'name': fields.String,
@@ -29,7 +44,8 @@ dataset_detail_fields = {
     'updated_at': TimestampField,
     'embedding_model': fields.String,
     'embedding_model_provider': fields.String,
-    'embedding_available': fields.Boolean
+    'embedding_available': fields.Boolean,
+    'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
 }
 
 dataset_query_detail_fields = {
@@ -41,3 +57,5 @@ dataset_query_detail_fields = {
     "created_by": fields.String,
     "created_at": TimestampField
 }
+
+

+ 43 - 0
api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py

@@ -0,0 +1,43 @@
+"""add-dataset-retrival-model
+
+Revision ID: fca025d3b60f
+Revises: b3a09c049e8e
+Create Date: 2023-11-03 13:08:23.246396
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'fca025d3b60f'
+down_revision = '8fe468ba0ca5'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('sessions')
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
+        batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
+        batch_op.drop_column('retrieval_model')
+
+    op.create_table('sessions',
+    sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
+    sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
+    sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
+    sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
+    sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
+    sa.UniqueConstraint('session_id', name='sessions_session_id_key')
+    )
+    # ### end Alembic commands ###

+ 18 - 4
api/models/dataset.py

@@ -3,7 +3,7 @@ import pickle
 from json import JSONDecodeError
 
 from sqlalchemy import func
-from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.dialects.postgresql import UUID, JSONB
 
 from extensions.ext_database import db
 from models.account import Account
@@ -15,6 +15,7 @@ class Dataset(db.Model):
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='dataset_pkey'),
         db.Index('dataset_tenant_idx', 'tenant_id'),
+        db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
     )
 
     INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
@@ -39,7 +40,7 @@ class Dataset(db.Model):
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
     collection_binding_id = db.Column(UUID, nullable=True)
-
+    retrieval_model = db.Column(JSONB, nullable=True)
 
     @property
     def dataset_keyword_table(self):
@@ -93,6 +94,20 @@ class Dataset(db.Model):
         return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
             .filter(Document.dataset_id == self.id).scalar()
 
+    @property
+    def retrieval_model_dict(self):
+        default_retrieval_model = {
+            'search_method': 'semantic_search',
+            'reranking_enable': False,
+            'reranking_model': {
+                'reranking_provider_name': '',
+                'reranking_model_name': ''
+            },
+            'top_k': 2,
+            'score_threshold_enable': False
+        }
+        return self.retrieval_model if self.retrieval_model else default_retrieval_model
+
 
 class DatasetProcessRule(db.Model):
     __tablename__ = 'dataset_process_rules'
@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
         ],
         'segmentation': {
             'delimiter': '\n',
-            'max_tokens': 1000
+            'max_tokens': 512
         }
     }
 
@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
     model_name = db.Column(db.String(40), nullable=False)
     collection_name = db.Column(db.String(64), nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
-

+ 7 - 1
api/models/model.py

@@ -160,7 +160,13 @@ class AppModelConfig(db.Model):
 
     @property
     def dataset_configs_dict(self) -> dict:
-        return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
+        if self.dataset_configs:
+            dataset_configs = json.loads(self.dataset_configs)
+            if 'retrieval_model' not in dataset_configs:
+                return {'retrieval_model': 'single'}
+            else:
+                return dataset_configs
+        return {'retrieval_model': 'single'}
 
     @property
     def file_upload_dict(self) -> dict:

+ 3 - 2
api/requirements.txt

@@ -23,7 +23,6 @@ boto3==1.28.17
 tenacity==8.2.2
 cachetools~=5.3.0
 weaviate-client~=3.21.0
-qdrant_client~=1.1.6
 mailchimp-transactional~=1.0.50
 scikit-learn==1.2.2
 sentry-sdk[flask]~=1.21.1
@@ -53,4 +52,6 @@ xinference-client~=0.5.4
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7
-pymilvus==2.3.0
+pymilvus==2.3.0
+qdrant-client==1.6.4
+cohere~=4.32

+ 10 - 1
api/services/app_model_config_service.py

@@ -470,7 +470,16 @@ class AppModelConfigService:
 
         # dataset_configs
         if 'dataset_configs' not in config or not config["dataset_configs"]:
-            config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
+            config["dataset_configs"] = {'retrieval_model': 'single'}
+
+        if not isinstance(config["dataset_configs"], dict):
+            raise ValueError("dataset_configs must be of object type")
+
+        if config["dataset_configs"]['retrieval_model'] == 'multiple':
+            if not config["dataset_configs"]['reranking_model']:
+                raise ValueError("reranking_model has not been set")
+            if not isinstance(config["dataset_configs"]['reranking_model'], dict):
+                raise ValueError("reranking_model must be of object type")
 
         if not isinstance(config["dataset_configs"], dict):
             raise ValueError("dataset_configs must be of object type")

+ 33 - 2
api/services/dataset_service.py

@@ -173,6 +173,9 @@ class DatasetService:
         filtered_data['updated_by'] = user.id
         filtered_data['updated_at'] = datetime.datetime.now()
 
+        # update Retrieval model
+        filtered_data['retrieval_model'] = data['retrieval_model']
+
         dataset.query.filter_by(id=dataset_id).update(filtered_data)
 
         db.session.commit()
@@ -473,7 +476,19 @@ class DocumentService:
                     embedding_model.name
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
+                if not dataset.retrieval_model:
+                    default_retrieval_model = {
+                        'search_method': 'semantic_search',
+                        'reranking_enable': False,
+                        'reranking_model': {
+                            'reranking_provider_name': '',
+                            'reranking_model_name': ''
+                        },
+                        'top_k': 2,
+                        'score_threshold_enable': False
+                    }
 
+                    dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -733,6 +748,7 @@ class DocumentService:
                 raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
         embedding_model = None
         dataset_collection_binding_id = None
+        retrieval_model = None
         if document_data['indexing_technique'] == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=tenant_id
@@ -742,6 +758,20 @@ class DocumentService:
                 embedding_model.name
             )
             dataset_collection_binding_id = dataset_collection_binding.id
+            if 'retrieval_model' in document_data and document_data['retrieval_model']:
+                retrieval_model = document_data['retrieval_model']
+            else:
+                default_retrieval_model = {
+                    'search_method': 'semantic_search',
+                    'reranking_enable': False,
+                    'reranking_model': {
+                        'reranking_provider_name': '',
+                        'reranking_model_name': ''
+                    },
+                    'top_k': 2,
+                    'score_threshold_enable': False
+                }
+                retrieval_model = default_retrieval_model
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -751,7 +781,8 @@ class DocumentService:
             created_by=account.id,
             embedding_model=embedding_model.name if embedding_model else None,
             embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
-            collection_binding_id=dataset_collection_binding_id
+            collection_binding_id=dataset_collection_binding_id,
+            retrieval_model=retrieval_model
         )
 
         db.session.add(dataset)
@@ -768,7 +799,7 @@ class DocumentService:
         return dataset, documents, batch
 
     @classmethod
-    def  document_create_args_validate(cls, args: dict):
+    def document_create_args_validate(cls, args: dict):
         if 'original_document_id' not in args or not args['original_document_id']:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)

+ 79 - 22
api/services/hit_testing_service.py

@@ -1,4 +1,6 @@
+import json
 import logging
+import threading
 import time
 from typing import List
 
@@ -9,16 +11,26 @@ from langchain.schema import Document
 from sklearn.manifold import TSNE
 
 from core.embedding.cached_embedding import CacheEmbedding
-from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
 from models.account import Account
 from models.dataset import Dataset, DocumentSegment, DatasetQuery
-
+from services.retrieval_service import RetrievalService
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enable': False
+}
 
 class HitTestingService:
     @classmethod
-    def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
+    def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
         if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return {
                 "query": {
@@ -28,31 +40,68 @@ class HitTestingService:
                 "records": []
             }
 
+        start = time.perf_counter()
+
+        # get retrieval model , if the model is not setting , using default
+        if not retrieval_model:
+            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+
+        # get embedding model
         embedding_model = ModelFactory.get_embedding_model(
             tenant_id=dataset.tenant_id,
             model_provider_name=dataset.embedding_model_provider,
             model_name=dataset.embedding_model
         )
-
         embeddings = CacheEmbedding(embedding_model)
 
-        vector_index = VectorIndex(
-            dataset=dataset,
-            config=current_app.config,
-            embeddings=embeddings
-        )
+        all_documents = []
+        threads = []
+
+        # retrieval_model source with semantic
+        if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
+            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset': dataset,
+                'query': query,
+                'top_k': retrieval_model['top_k'],
+                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
+                'all_documents': all_documents,
+                'search_method': retrieval_model['search_method'],
+                'embeddings': embeddings
+            })
+            threads.append(embedding_thread)
+            embedding_thread.start()
+
+        # retrieval source with full text
+        if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
+            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
+                'flask_app': current_app._get_current_object(),
+                'dataset': dataset,
+                'query': query,
+                'search_method': retrieval_model['search_method'],
+                'embeddings': embeddings,
+                'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                'top_k': retrieval_model['top_k'],
+                'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
+                'all_documents': all_documents
+            })
+            threads.append(full_text_index_thread)
+            full_text_index_thread.start()
+
+        for thread in threads:
+            thread.join()
+
+        if retrieval_model['search_method'] == 'hybrid_search':
+            hybrid_rerank = ModelFactory.get_reranking_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
+                model_name=retrieval_model['reranking_model']['reranking_model_name']
+            )
+            all_documents = hybrid_rerank.rerank(query, all_documents,
+                                                 retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
+                                                 retrieval_model['top_k'])
 
-        start = time.perf_counter()
-        documents = vector_index.search(
-            query,
-            search_type='similarity_score_threshold',
-            search_kwargs={
-                'k': 10,
-                'filter': {
-                    'group_id': [dataset.id]
-                }
-            }
-        )
         end = time.perf_counter()
         logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
 
@@ -67,7 +116,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()
 
-        return cls.compact_retrieve_response(dataset, embeddings, query, documents)
+        return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)
 
     @classmethod
     def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
@@ -99,7 +148,7 @@ class HitTestingService:
 
             record = {
                 "segment": segment,
-                "score": document.metadata['score'],
+                "score": document.metadata.get('score', None),
                 "tsne_position": tsne_position_data[i]
             }
 
@@ -136,3 +185,11 @@ class HitTestingService:
             tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])})
 
         return tsne_position_data
+
+    @classmethod
+    def hit_testing_args_check(cls, args):
+        query = args['query']
+
+        if not query or len(query) > 250:
+            raise ValueError('Query is required and cannot exceed 250 characters')
+

+ 88 - 0
api/services/retrieval_service.py

@@ -0,0 +1,88 @@
+
+from typing import Optional
+from flask import current_app, Flask
+from langchain.embeddings.base import Embeddings
+from core.index.vector_index.vector_index import VectorIndex
+from core.model_providers.model_factory import ModelFactory
+from models.dataset import Dataset
+
+default_retrieval_model = {
+    'search_method': 'semantic_search',
+    'reranking_enable': False,
+    'reranking_model': {
+        'reranking_provider_name': '',
+        'reranking_model_name': ''
+    },
+    'top_k': 2,
+    'score_threshold_enable': False
+}
+
+
+class RetrievalService:
+
+    @classmethod
+    def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
+                         top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
+                         all_documents: list, search_method: str, embeddings: Embeddings):
+        with flask_app.app_context():
+
+            vector_index = VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+
+            documents = vector_index.search(
+                query,
+                search_type='similarity_score_threshold',
+                search_kwargs={
+                    'k': top_k,
+                    'score_threshold': score_threshold,
+                    'filter': {
+                        'group_id': [dataset.id]
+                    }
+                }
+            )
+
+            if documents:
+                if reranking_model and search_method == 'semantic_search':
+                    rerank = ModelFactory.get_reranking_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=reranking_model['reranking_provider_name'],
+                        model_name=reranking_model['reranking_model_name']
+                    )
+                    all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
+                else:
+                    all_documents.extend(documents)
+
+    @classmethod
+    def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
+                               top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
+                               all_documents: list, search_method: str, embeddings: Embeddings):
+        with flask_app.app_context():
+
+            vector_index = VectorIndex(
+                dataset=dataset,
+                config=current_app.config,
+                embeddings=embeddings
+            )
+
+            documents = vector_index.search_by_full_text_index(
+                query,
+                search_type='similarity_score_threshold',
+                top_k=top_k
+            )
+            if documents:
+                if reranking_model and search_method == 'full_text_search':
+                    rerank = ModelFactory.get_reranking_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=reranking_model['reranking_provider_name'],
+                        model_name=reranking_model['reranking_model_name']
+                    )
+                    all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
+                else:
+                    all_documents.extend(documents)
+
+
+
+