ソースを参照

Feat: dataset retriever resource (#1123)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Jyong 1 年間 前
コミット
642842d61b
32 ファイル変更、442 行追加、33 行削除
  1. 1 0
      api/controllers/console/app/app.py
  2. 2 0
      api/controllers/console/app/completion.py
  3. 2 0
      api/controllers/console/explore/completion.py
  4. 20 0
      api/controllers/console/explore/message.py
  5. 2 0
      api/controllers/console/explore/parameter.py
  6. 2 0
      api/controllers/console/universal_chat/chat.py
  7. 20 0
      api/controllers/console/universal_chat/message.py
  8. 5 0
      api/controllers/console/universal_chat/parameter.py
  9. 1 0
      api/controllers/console/universal_chat/wraps.py
  10. 2 0
      api/controllers/service_api/app/app.py
  11. 4 0
      api/controllers/service_api/app/completion.py
  12. 19 0
      api/controllers/service_api/app/message.py
  13. 2 0
      api/controllers/web/app.py
  14. 4 0
      api/controllers/web/completion.py
  15. 20 0
      api/controllers/web/message.py
  16. 5 0
      api/core/agent/agent/multi_dataset_router_agent.py
  17. 1 4
      api/core/callback_handler/dataset_tool_callback_handler.py
  18. 7 1
      api/core/callback_handler/index_tool_callback_handler.py
  19. 6 4
      api/core/completion.py
  20. 53 3
      api/core/conversation_message_task.py
  21. 1 1
      api/core/index/keyword_table_index/keyword_table_index.py
  22. 19 0
      api/core/index/vector_index/qdrant_vector_index.py
  23. 1 0
      api/core/model_providers/models/entity/message.py
  24. 20 9
      api/core/orchestrator_rule_parser.py
  25. 1 1
      api/core/prompt/generate_prompts/common_chat.json
  26. 1 1
      api/core/prompt/prompts.py
  27. 48 4
      api/core/tool/dataset_retriever_tool.py
  28. 54 0
      api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py
  29. 32 0
      api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py
  30. 44 1
      api/models/model.py
  31. 16 0
      api/services/app_model_config_service.py
  32. 27 4
      api/services/completion_service.py

+ 1 - 0
api/controllers/console/app/app.py

@@ -29,6 +29,7 @@ model_config_fields = {
     'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
     'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
     'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
+    'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
     'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
     'model': fields.Raw(attribute='model_dict'),

+ 2 - 0
api/controllers/console/app/completion.py

@@ -42,6 +42,7 @@ class CompletionMessageApi(Resource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] != 'blocking'
@@ -115,6 +116,7 @@ class ChatMessageApi(Resource):
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] != 'blocking'

+ 2 - 0
api/controllers/console/explore/completion.py

@@ -33,6 +33,7 @@ class CompletionApi(InstalledAppResource):
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'

+ 20 - 0
api/controllers/console/explore/message.py

@@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource):
         'rating': fields.String
     }
 
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
+
     message_fields = {
         'id': fields.String,
         'conversation_id': fields.String,
@@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource):
         'query': fields.String,
         'answer': fields.String,
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
     }
 

+ 2 - 0
api/controllers/console/explore/parameter.py

@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
     }
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 2 - 0
api/controllers/console/universal_chat/chat.py

@@ -29,9 +29,11 @@ class UniversalChatApi(UniversalChatResource):
         parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
         args = parser.parse_args()
 
         app_model_config = app_model.app_model_config
+        app_model_config
 
         # update app model config
         args['model_config'] = app_model_config.to_dict()

+ 20 - 0
api/controllers/console/universal_chat/message.py

@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource):
         'created_at': TimestampField
     }
 
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
+
     message_fields = {
         'id': fields.String,
         'conversation_id': fields.String,
@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource):
         'query': fields.String,
         'answer': fields.String,
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField,
         'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
     }

+ 5 - 0
api/controllers/console/universal_chat/parameter.py

@@ -1,4 +1,6 @@
 # -*- coding:utf-8 -*-
+import json
+
 from flask_restful import marshal_with, fields
 
 from controllers.console import api
@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
     }
 
     @marshal_with(parameters_fields)
@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource):
         """Retrieve app parameters."""
         app_model = universal_app
         app_model_config = app_model.app_model_config
+        app_model_config.retriever_resource = json.dumps({'enabled': True})
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
         }
 
 

+ 1 - 0
api/controllers/console/universal_chat/wraps.py

@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
                     suggested_questions=json.dumps([]),
                     suggested_questions_after_answer=json.dumps({'enabled': True}),
                     speech_to_text=json.dumps({'enabled': True}),
+                    retriever_resource=json.dumps({'enabled': True}),
                     more_like_this=None,
                     sensitive_word_avoidance=None,
                     model=json.dumps({

+ 2 - 0
api/controllers/service_api/app/app.py

@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
     }
@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource):
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 4 - 0
api/controllers/service_api/app/completion.py

@@ -30,6 +30,8 @@ class CompletionApi(AppApiResource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
@@ -91,6 +93,8 @@ class ChatApi(AppApiResource):
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'

+ 19 - 0
api/controllers/service_api/app/message.py

@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource):
     feedback_fields = {
         'rating': fields.String
     }
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
 
     message_fields = {
         'id': fields.String,
@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource):
         'query': fields.String,
         'answer': fields.String,
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
     }
 

+ 2 - 0
api/controllers/web/app.py

@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
     }
@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource):
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

+ 4 - 0
api/controllers/web/completion.py

@@ -31,6 +31,8 @@ class CompletionApi(WebApiResource):
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
+
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
@@ -88,6 +90,8 @@ class ChatApi(WebApiResource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
+
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'

+ 20 - 0
api/controllers/web/message.py

@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource):
         'rating': fields.String
     }
 
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
+
     message_fields = {
         'id': fields.String,
         'conversation_id': fields.String,
@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource):
         'query': fields.String,
         'answer': fields.String,
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
     }
 

+ 5 - 0
api/core/agent/agent/multi_dataset_router_agent.py

@@ -1,3 +1,4 @@
+import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast
 
 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
             tool = next(iter(self.tools))
             tool = cast(DatasetRetrieverTool, tool)
             rst = tool.run(tool_input={'query': kwargs['input']})
+            # output = ''
+            # rst_json = json.loads(rst)
+            # for item in rst_json:
+            #     output += f'{item["content"]}\n'
             return AgentFinish(return_values={"output": rst}, log=rst)
 
         if intermediate_steps:

+ 1 - 4
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         llm_prefix: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
-        # kwargs={'name': 'Search'}
-        # llm_prefix='Thought:'
-        # observation_prefix='Observation: '
-        # output='53 years'
         pass
 
+
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:

+ 7 - 1
api/core/callback_handler/index_tool_callback_handler.py

@@ -2,6 +2,7 @@ from typing import List
 
 from langchain.schema import Document
 
+from core.conversation_message_task import ConversationMessageTask
 from extensions.ext_database import db
 from models.dataset import DocumentSegment
 
@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
 
-    def __init__(self, dataset_id: str) -> None:
+    def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
         self.dataset_id = dataset_id
+        self.conversation_message_task = conversation_message_task
 
     def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler:
             )
 
             db.session.commit()
+
+    def return_retriever_resource_info(self, resource: List):
+        """Handle return_retriever_resource_info."""
+        self.conversation_message_task.on_dataset_query_finish(resource)

+ 6 - 4
api/core/completion.py

@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 from typing import Optional, List, Union, Tuple
@@ -19,13 +20,15 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
+from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
 
 
 class Completion:
     @classmethod
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
+                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
+                 is_override: bool = False, retriever_from: str = 'dev'):
         """
         errors: ProviderTokenNotInitError
         """
@@ -96,7 +99,6 @@ class Completion:
             should_use_agent = agent_executor.should_use_agent(query)
             if should_use_agent:
                 agent_execute_result = agent_executor.run(query)
-
         # run the final llm
         try:
             cls.run_final_llm(
@@ -118,7 +120,8 @@ class Completion:
             return
 
     @classmethod
-    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
+                      inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
@@ -150,7 +153,6 @@ class Completion:
             callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
             fake_response=fake_response
         )
-
         return response
 
     @classmethod

+ 53 - 3
api/core/conversation_message_task.py

@@ -1,6 +1,6 @@
 import decimal
 import json
-from typing import Optional, Union
+from typing import Optional, Union, List
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
@@ -15,7 +15,8 @@ from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
-from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
+from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
+    MessageChain, DatasetRetrieverResource
 
 
 class ConversationMessageTask:
@@ -41,6 +42,8 @@ class ConversationMessageTask:
 
         self.message = None
 
+        self.retriever_resource = None
+
         self.model_dict = self.app_model_config.model_dict
         self.provider_name = self.model_dict.get('provider')
         self.model_name = self.model_dict.get('name')
@@ -157,7 +160,8 @@ class ConversationMessageTask:
         self.message.message_tokens = message_tokens
         self.message.message_unit_price = message_unit_price
         self.message.message_price_unit = message_price_unit
-        self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
+        self.message.answer = PromptBuilder.process_template(
+            llm_message.completion.strip()) if llm_message.completion else ''
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
         self.message.answer_price_unit = answer_price_unit
@@ -256,7 +260,36 @@ class ConversationMessageTask:
 
         db.session.add(dataset_query)
 
+    def on_dataset_query_finish(self, resource: List):
+        if resource and len(resource) > 0:
+            for item in resource:
+                dataset_retriever_resource = DatasetRetrieverResource(
+                    message_id=self.message.id,
+                    position=item.get('position'),
+                    dataset_id=item.get('dataset_id'),
+                    dataset_name=item.get('dataset_name'),
+                    document_id=item.get('document_id'),
+                    document_name=item.get('document_name'),
+                    data_source_type=item.get('data_source_type'),
+                    segment_id=item.get('segment_id'),
+                    score=item.get('score') if 'score' in item else None,
+                    hit_count=item.get('hit_count') if 'hit_count' else None,
+                    word_count=item.get('word_count') if 'word_count' in item else None,
+                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
+                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
+                    content=item.get('content'),
+                    retriever_from=item.get('retriever_from'),
+                    created_by=self.user.id
+                )
+                db.session.add(dataset_retriever_resource)
+                db.session.flush()
+            self.retriever_resource = resource
+
+    def message_end(self):
+        self._pub_handler.pub_message_end(self.retriever_resource)
+
     def end(self):
+        self._pub_handler.pub_message_end(self.retriever_resource)
         self._pub_handler.pub_end()
 
 
@@ -350,6 +383,23 @@ class PubHandler:
             self.pub_end()
             raise ConversationTaskStoppedException()
 
+    def pub_message_end(self, retriever_resource: List):
+        content = {
+            'event': 'message_end',
+            'data': {
+                'task_id': self._task_id,
+                'message_id': self._message.id,
+                'mode': self._conversation.mode,
+                'conversation_id': self._conversation.id
+            }
+        }
+        if retriever_resource:
+            content['data']['retriever_resources'] = retriever_resource
+        redis_client.publish(self._channel, json.dumps(content))
+
+        if self._is_stopped():
+            self.pub_end()
+            raise ConversationTaskStoppedException()
 
     def pub_end(self):
         content = {

+ 1 - 1
api/core/index/keyword_table_index/keyword_table_index.py

@@ -74,7 +74,7 @@ class KeywordTableIndex(BaseIndex):
             DocumentSegment.document_id == document_id
         ).all()
 
-        ids = [segment.id for segment in segments]
+        ids = [segment.index_node_id for segment in segments]
 
         keyword_table = self._get_dataset_keyword_table()
         keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

+ 19 - 0
api/core/index/vector_index/qdrant_vector_index.py

@@ -113,6 +113,25 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))
 
+    def delete_by_ids(self, ids: list[str]) -> None:
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        for node_id in ids:
+            vector_store.del_texts(models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="metadata.doc_id",
+                        match=models.MatchValue(value=node_id),
+                    ),
+                ],
+            ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
             class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

+ 1 - 0
api/core/model_providers/models/entity/message.py

@@ -8,6 +8,7 @@ class LLMRunResult(BaseModel):
     content: str
     prompt_tokens: int
     completion_tokens: int
+    source: list = None
 
 
 class MessageType(enum.Enum):

+ 20 - 9
api/core/orchestrator_rule_parser.py

@@ -36,8 +36,8 @@ class OrchestratorRuleParser:
         self.app_model_config = app_model_config
 
     def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-            -> Optional[AgentExecutor]:
+                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
+                          return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
         if not self.app_model_config.agent_mode_dict:
             return None
 
@@ -74,7 +74,7 @@ class OrchestratorRuleParser:
 
             # only OpenAI chat model (include Azure) support function call, use ReACT instead
             if agent_model_instance.model_mode != ModelMode.CHAT \
-                         or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
+                    or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
                 if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                     planning_strategy = PlanningStrategy.REACT
                 elif planning_strategy == PlanningStrategy.ROUTER:
@@ -99,7 +99,9 @@ class OrchestratorRuleParser:
                 tool_configs=tool_configs,
                 conversation_message_task=conversation_message_task,
                 rest_tokens=rest_tokens,
-                callbacks=[agent_callback, DifyStdOutCallbackHandler()]
+                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
+                return_resource=return_resource,
+                retriever_from=retriever_from
             )
 
             if len(tools) == 0:
@@ -145,8 +147,10 @@ class OrchestratorRuleParser:
 
         return None
 
-    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
-                 rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
+                 conversation_message_task: ConversationMessageTask,
+                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
+                 retriever_from: str = 'dev') -> list[BaseTool]:
         """
         Convert app agent tool configs to tools
 
@@ -155,6 +159,8 @@ class OrchestratorRuleParser:
         :param tool_configs: app agent tool configs
         :param conversation_message_task:
         :param callbacks:
+        :param return_resource:
+        :param retriever_from:
         :return:
         """
         tools = []
@@ -166,7 +172,7 @@ class OrchestratorRuleParser:
 
             tool = None
             if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
+                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
             elif tool_type == "web_reader":
                 tool = self.to_web_reader_tool(agent_model_instance)
             elif tool_type == "google_search":
@@ -183,13 +189,15 @@ class OrchestratorRuleParser:
         return tools
 
     def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                  rest_tokens: int) \
+                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
             -> Optional[BaseTool]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param rest_tokens:
         :param tool_config:
         :param conversation_message_task:
+        :param return_resource:
+        :param retriever_from:
         :return:
         """
         # get dataset from dataset id
@@ -208,7 +216,10 @@ class OrchestratorRuleParser:
         tool = DatasetRetrieverTool.from_dataset(
             dataset=dataset,
             k=k,
-            callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
+            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
+            conversation_message_task=conversation_message_task,
+            return_resource=return_resource,
+            retriever_from=retriever_from
         )
 
         return tool

+ 1 - 1
api/core/prompt/generate_prompts/common_chat.json

@@ -10,4 +10,4 @@
   ],
   "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
   "stops": ["\nHuman:", "</histories>"]
-}
+}

+ 1 - 1
api/core/prompt/prompts.py

@@ -105,7 +105,7 @@ GENERATOR_QA_PROMPT = (
     'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
     'Step 4: Generate 20 questions and answers based on these key information and concepts.'
     'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
-    "Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
+    "Answer according to the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
 )
 
 RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \

+ 48 - 4
api/core/tool/dataset_retriever_tool.py

@@ -1,3 +1,4 @@
+import json
 from typing import Type
 
 from flask import current_app
@@ -5,13 +6,14 @@ from langchain.tools import BaseTool
 from pydantic import Field, BaseModel
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.conversation_message_task import ConversationMessageTask
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, Document
 
 
 class DatasetRetrieverToolInput(BaseModel):
@@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool):
     tenant_id: str
     dataset_id: str
     k: int = 3
+    conversation_message_task: ConversationMessageTask
+    return_resource: bool
+    retriever_from: str
+
 
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool):
             if self.k > 0:
                 documents = vector_index.search(
                     query,
-                    search_type='similarity',
+                    search_type='similarity_score_threshold',
                     search_kwargs={
                         'k': self.k
                     }
@@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool):
             else:
                 documents = []
 
-            hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
+            hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
             hit_callback.on_tool_end(documents)
+            document_score_list = {}
+            if dataset.indexing_technique != "economy":
+                for item in documents:
+                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
             document_context_list = []
             index_node_ids = [document.metadata['doc_id'] for document in documents]
             segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
@@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool):
                                                                                            float('inf')))
                 for segment in sorted_segments:
                     if segment.answer:
-                        document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
+                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                     else:
                         document_context_list.append(segment.content)
+                if self.return_resource:
+                    context_list = []
+                    resource_number = 1
+                    for segment in sorted_segments:
+                        context = {}
+                        document = Document.query.filter(Document.id == segment.document_id,
+                                                         Document.enabled == True,
+                                                         Document.archived == False,
+                                                         ).first()
+                        if dataset and document:
+                            source = {
+                                'position': resource_number,
+                                'dataset_id': dataset.id,
+                                'dataset_name': dataset.name,
+                                'document_id': document.id,
+                                'document_name': document.name,
+                                'data_source_type': document.data_source_type,
+                                'segment_id': segment.id,
+                                'retriever_from': self.retriever_from
+                            }
+                            if dataset.indexing_technique != "economy":
+                                source['score'] = document_score_list.get(segment.index_node_id)
+                            if self.retriever_from == 'dev':
+                                source['hit_count'] = segment.hit_count
+                                source['word_count'] = segment.word_count
+                                source['segment_position'] = segment.position
+                                source['index_node_hash'] = segment.index_node_hash
+                            if segment.answer:
+                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                            else:
+                                source['content'] = segment.content
+                            context_list.append(source)
+                        resource_number += 1
+                    hit_callback.return_retriever_resource_info(context_list)
 
             return str("\n".join(document_context_list))
 

+ 54 - 0
api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py

@@ -0,0 +1,54 @@
+"""add_dataset_retriever_resource
+
+Revision ID: 6dcb43972bdc
+Revises: 4bcffcd64aa4
+Create Date: 2023-09-06 16:51:27.385844
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '6dcb43972bdc'
+down_revision = '4bcffcd64aa4'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset_retriever_resources',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('message_id', postgresql.UUID(), nullable=False),
+    sa.Column('position', sa.Integer(), nullable=False),
+    sa.Column('dataset_id', postgresql.UUID(), nullable=False),
+    sa.Column('dataset_name', sa.Text(), nullable=False),
+    sa.Column('document_id', postgresql.UUID(), nullable=False),
+    sa.Column('document_name', sa.Text(), nullable=False),
+    sa.Column('data_source_type', sa.Text(), nullable=False),
+    sa.Column('segment_id', postgresql.UUID(), nullable=False),
+    sa.Column('score', sa.Float(), nullable=True),
+    sa.Column('content', sa.Text(), nullable=False),
+    sa.Column('hit_count', sa.Integer(), nullable=True),
+    sa.Column('word_count', sa.Integer(), nullable=True),
+    sa.Column('segment_position', sa.Integer(), nullable=True),
+    sa.Column('index_node_hash', sa.Text(), nullable=True),
+    sa.Column('retriever_from', sa.Text(), nullable=False),
+    sa.Column('created_by', postgresql.UUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey')
+    )
+    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+        batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+        batch_op.drop_index('dataset_retriever_resource_message_id_idx')
+
+    op.drop_table('dataset_retriever_resources')
+    # ### end Alembic commands ###

+ 32 - 0
api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py

@@ -0,0 +1,32 @@
+"""add_app_config_retriever_resource
+
+Revision ID: 77e83833755c
+Revises: 6dcb43972bdc
+Create Date: 2023-09-06 17:26:40.311927
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '77e83833755c'
+down_revision = '6dcb43972bdc'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('retriever_resource')
+
+    # ### end Alembic commands ###

+ 44 - 1
api/models/model.py

@@ -1,4 +1,5 @@
 import json
+from json import JSONDecodeError
 
 from flask import current_app, request
 from flask_login import UserMixin
@@ -90,6 +91,7 @@ class AppModelConfig(db.Model):
     pre_prompt = db.Column(db.Text)
     agent_mode = db.Column(db.Text)
     sensitive_word_avoidance = db.Column(db.Text)
+    retriever_resource = db.Column(db.Text)
 
     @property
     def app(self):
@@ -114,6 +116,11 @@ class AppModelConfig(db.Model):
         return json.loads(self.speech_to_text) if self.speech_to_text \
             else {"enabled": False}
 
+    @property
+    def retriever_resource_dict(self) -> dict:
+        return json.loads(self.retriever_resource) if self.retriever_resource \
+            else {"enabled": False}
+
     @property
     def more_like_this_dict(self) -> dict:
         return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@@ -140,6 +147,7 @@ class AppModelConfig(db.Model):
             "suggested_questions": self.suggested_questions_list,
             "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
             "speech_to_text": self.speech_to_text_dict,
             "retriever_resource": self.retriever_resource_dict,
             "more_like_this": self.more_like_this_dict,
             "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
             "model": self.model_dict,
@@ -164,7 +172,8 @@ class AppModelConfig(db.Model):
         self.user_input_form = json.dumps(model_config['user_input_form'])
         self.pre_prompt = model_config['pre_prompt']
         self.agent_mode = json.dumps(model_config['agent_mode'])
-
+        self.retriever_resource = json.dumps(model_config['retriever_resource']) \
+            if model_config.get('retriever_resource') else None
         return self
 
     def copy(self):
@@ -318,6 +327,7 @@ class Conversation(db.Model):
             model_config['suggested_questions'] = app_model_config.suggested_questions_list
             model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
             model_config['speech_to_text'] = app_model_config.speech_to_text_dict
+            model_config['retriever_resource'] = app_model_config.retriever_resource_dict
             model_config['more_like_this'] = app_model_config.more_like_this_dict
             model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
             model_config['user_input_form'] = app_model_config.user_input_form_list
@@ -476,6 +486,11 @@ class Message(db.Model):
         return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
             .order_by(MessageAgentThought.position.asc()).all()
 
+    @property
+    def retriever_resources(self):
+        return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
+            .order_by(DatasetRetrieverResource.position.asc()).all()
+
 
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
@@ -719,3 +734,31 @@ class MessageAgentThought(db.Model):
     created_by_role = db.Column(db.String, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+
+
+class DatasetRetrieverResource(db.Model):
+    __tablename__ = 'dataset_retriever_resources'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'),
+        db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
+    )
+
+    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(UUID, nullable=False)
+    position = db.Column(db.Integer, nullable=False)
+    dataset_id = db.Column(UUID, nullable=False)
+    dataset_name = db.Column(db.Text, nullable=False)
+    document_id = db.Column(UUID, nullable=False)
+    document_name = db.Column(db.Text, nullable=False)
+    data_source_type = db.Column(db.Text, nullable=False)
+    segment_id = db.Column(UUID, nullable=False)
+    score = db.Column(db.Float, nullable=True)
+    content = db.Column(db.Text, nullable=False)
+    hit_count = db.Column(db.Integer, nullable=True)
+    word_count = db.Column(db.Integer, nullable=True)
+    segment_position = db.Column(db.Integer, nullable=True)
+    index_node_hash = db.Column(db.Text, nullable=True)
+    retriever_from = db.Column(db.Text, nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+

+ 16 - 0
api/services/app_model_config_service.py

@@ -130,6 +130,21 @@ class AppModelConfigService:
         if not isinstance(config["speech_to_text"]["enabled"], bool):
             raise ValueError("enabled in speech_to_text must be of boolean type")
 
+        # return retriever resource
+        if 'retriever_resource' not in config or not config["retriever_resource"]:
+            config["retriever_resource"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["retriever_resource"], dict):
+            raise ValueError("retriever_resource must be of dict type")
+
+        if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
+            config["retriever_resource"]["enabled"] = False
+
+        if not isinstance(config["retriever_resource"]["enabled"], bool):
+            raise ValueError("enabled in retriever_resource must be of boolean type")
+
         # more_like_this
         if 'more_like_this' not in config or not config["more_like_this"]:
             config["more_like_this"] = {
@@ -327,6 +342,7 @@ class AppModelConfigService:
             "suggested_questions": config["suggested_questions"],
             "suggested_questions_after_answer": config["suggested_questions_after_answer"],
             "speech_to_text": config["speech_to_text"],
+            "retriever_resource": config["retriever_resource"],
             "more_like_this": config["more_like_this"],
             "sensitive_word_avoidance": config["sensitive_word_avoidance"],
             "model": {

+ 27 - 4
api/services/completion_service.py

@@ -11,7 +11,8 @@ from sqlalchemy import and_
 
 from core.completion import Completion
 from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
-from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, \
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -95,6 +96,7 @@ class CompletionService:
 
                 app_model_config_model = app_model_config.model_dict
                 app_model_config_model['completion_params'] = completion_params
+                app_model_config.retriever_resource = json.dumps({'enabled': True})
 
                 app_model_config = app_model_config.copy()
                 app_model_config.model = json.dumps(app_model_config_model)
@@ -145,7 +147,8 @@ class CompletionService:
             'user': user,
             'conversation': conversation,
             'streaming': streaming,
-            'is_model_config_override': is_model_config_override
+            'is_model_config_override': is_model_config_override,
+            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
         })
 
         generate_worker_thread.start()
@@ -169,7 +172,8 @@ class CompletionService:
     @classmethod
     def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
                         query: str, inputs: dict, user: Union[Account, EndUser],
-                        conversation: Conversation, streaming: bool, is_model_config_override: bool):
+                        conversation: Conversation, streaming: bool, is_model_config_override: bool,
+                        retriever_from: str = 'dev'):
         with flask_app.app_context():
             try:
                 if conversation:
@@ -188,6 +192,7 @@ class CompletionService:
                     conversation=conversation,
                     streaming=streaming,
                     is_override=is_model_config_override,
+                    retriever_from=retriever_from
                 )
             except ConversationTaskStoppedException:
                 pass
@@ -400,7 +405,11 @@ class CompletionService:
                             elif event == 'chain':
                                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                             elif event == 'agent_thought':
-                                yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
+                                yield "data: " + json.dumps(
+                                    cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
+                            elif event == 'message_end':
+                                yield "data: " + json.dumps(
+                                    cls.get_message_end_data(result.get('data'))) + "\n\n"
                             elif event == 'ping':
                                 yield "event: ping\n\n"
                             else:
@@ -432,6 +441,20 @@ class CompletionService:
 
         return response_data
 
+    @classmethod
+    def get_message_end_data(cls, data: dict):
+        response_data = {
+            'event': 'message_end',
+            'task_id': data.get('task_id'),
+            'id': data.get('message_id')
+        }
+        if 'retriever_resources' in data:
+            response_data['retriever_resources'] = data.get('retriever_resources')
+        if data.get('mode') == 'chat':
+            response_data['conversation_id'] = data.get('conversation_id')
+
+        return response_data
+
     @classmethod
     def get_chain_response_data(cls, data: dict):
         response_data = {