
Feat: dataset retriever resource (#1123)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Jyong, 1 year ago
commit 642842d61b
32 changed files with 442 additions and 33 deletions
  1. api/controllers/console/app/app.py (+1, -0)
  2. api/controllers/console/app/completion.py (+2, -0)
  3. api/controllers/console/explore/completion.py (+2, -0)
  4. api/controllers/console/explore/message.py (+20, -0)
  5. api/controllers/console/explore/parameter.py (+2, -0)
  6. api/controllers/console/universal_chat/chat.py (+2, -0)
  7. api/controllers/console/universal_chat/message.py (+20, -0)
  8. api/controllers/console/universal_chat/parameter.py (+5, -0)
  9. api/controllers/console/universal_chat/wraps.py (+1, -0)
  10. api/controllers/service_api/app/app.py (+2, -0)
  11. api/controllers/service_api/app/completion.py (+4, -0)
  12. api/controllers/service_api/app/message.py (+19, -0)
  13. api/controllers/web/app.py (+2, -0)
  14. api/controllers/web/completion.py (+4, -0)
  15. api/controllers/web/message.py (+20, -0)
  16. api/core/agent/agent/multi_dataset_router_agent.py (+5, -0)
  17. api/core/callback_handler/dataset_tool_callback_handler.py (+1, -4)
  18. api/core/callback_handler/index_tool_callback_handler.py (+7, -1)
  19. api/core/completion.py (+6, -4)
  20. api/core/conversation_message_task.py (+53, -3)
  21. api/core/index/keyword_table_index/keyword_table_index.py (+1, -1)
  22. api/core/index/vector_index/qdrant_vector_index.py (+19, -0)
  23. api/core/model_providers/models/entity/message.py (+1, -0)
  24. api/core/orchestrator_rule_parser.py (+20, -9)
  25. api/core/prompt/generate_prompts/common_chat.json (+1, -1)
  26. api/core/prompt/prompts.py (+1, -1)
  27. api/core/tool/dataset_retriever_tool.py (+48, -4)
  28. api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py (+54, -0)
  29. api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py (+32, -0)
  30. api/models/model.py (+44, -1)
  31. api/services/app_model_config_service.py (+16, -0)
  32. api/services/completion_service.py (+27, -4)

+ 1 - 0
api/controllers/console/app/app.py

@@ -29,6 +29,7 @@ model_config_fields = {
     'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
     'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
     'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
+    'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
     'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
     'model': fields.Raw(attribute='model_dict'),

+ 2 - 0
api/controllers/console/app/completion.py

@@ -42,6 +42,7 @@ class CompletionMessageApi(Resource):
        parser.add_argument('query', type=str, location='json', default='')
        parser.add_argument('model_config', type=dict, required=True, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] != 'blocking'
@@ -115,6 +116,7 @@ class ChatMessageApi(Resource):
        parser.add_argument('model_config', type=dict, required=True, location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] != 'blocking'

+ 2 - 0
api/controllers/console/explore/completion.py

@@ -33,6 +33,7 @@ class CompletionApi(InstalledAppResource):
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource):
        parser.add_argument('query', type=str, required=True, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'

+ 20 - 0
api/controllers/console/explore/message.py

@@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource):
        'rating': fields.String
    }

+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
+
    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
@@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource):
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
    }

+ 2 - 0
api/controllers/console/explore/parameter.py

@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
    }
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list
        }

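For context, an app "parameters" response now carries the extra retriever_resource block alongside the existing feature flags. A hedged, illustrative slice (the exact values depend on the app's configuration; per models/model.py below, retriever_resource_dict falls back to {'enabled': False} when the column is unset):

    # Illustrative parameters payload slice after this change (values are examples only).
    example_parameters = {
        'speech_to_text': {'enabled': False},
        'retriever_resource': {'enabled': True},
        'more_like_this': {'enabled': False},
    }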
+ 2 - 0
api/controllers/console/universal_chat/chat.py

@@ -29,9 +29,11 @@ class UniversalChatApi(UniversalChatResource):
        parser.add_argument('provider', type=str, required=True, location='json')
        parser.add_argument('model', type=str, required=True, location='json')
        parser.add_argument('tools', type=list, required=True, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
        args = parser.parse_args()

        app_model_config = app_model.app_model_config
+        app_model_config

        # update app model config
        args['model_config'] = app_model_config.to_dict()

+ 20 - 0
api/controllers/console/universal_chat/message.py

@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource):
        'created_at': TimestampField
    }

+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
+
    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource):
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField,
        'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
    }

+ 5 - 0
api/controllers/console/universal_chat/parameter.py

@@ -1,4 +1,6 @@
 # -*- coding:utf-8 -*-
+import json
+
 from flask_restful import marshal_with, fields

 from controllers.console import api
@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
    }

    @marshal_with(parameters_fields)
@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource):
         """Retrieve app parameters."""
         """Retrieve app parameters."""
         app_model = universal_app
         app_model = universal_app
         app_model_config = app_model.app_model_config
         app_model_config = app_model.app_model_config
+        app_model_config.retriever_resource = json.dumps({'enabled': True})
 
 
         return {
         return {
             'opening_statement': app_model_config.opening_statement,
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
         }
         }
 
 
 
 

+ 1 - 0
api/controllers/console/universal_chat/wraps.py

@@ -47,6 +47,7 @@ def universal_chat_app_required(view=None):
                    suggested_questions=json.dumps([]),
                    suggested_questions_after_answer=json.dumps({'enabled': True}),
                    speech_to_text=json.dumps({'enabled': True}),
+                    retriever_resource=json.dumps({'enabled': True}),
                    more_like_this=None,
                    sensitive_word_avoidance=None,
                    model=json.dumps({

+ 2 - 0
api/controllers/service_api/app/app.py

@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
    }
@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list
        }

+ 4 - 0
api/controllers/service_api/app/completion.py

@@ -30,6 +30,8 @@ class CompletionApi(AppApiResource):
        parser.add_argument('query', type=str, location='json', default='')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('user', type=str, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
@@ -91,6 +93,8 @@ class ChatApi(AppApiResource):
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('user', type=str, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'

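For illustration, a client of the service API could pass the new optional retriever_from field like this. A minimal sketch assuming the usual /v1/chat-messages route and a placeholder host and API key; when the field is omitted the parser defaults it to 'dev':

    import requests

    # Hypothetical host and key; only 'retriever_from' is new in this commit.
    resp = requests.post(
        'https://api.dify.example/v1/chat-messages',
        headers={'Authorization': 'Bearer app-xxxxxxxxxxxx'},
        json={
            'inputs': {},
            'query': 'What does the handbook say about refunds?',
            'response_mode': 'blocking',
            'user': 'end-user-1',
            'retriever_from': 'dev',
        },
    )
    print(resp.json())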
+ 19 - 0
api/controllers/service_api/app/message.py

@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource):
    feedback_fields = {
        'rating': fields.String
    }
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }

    message_fields = {
        'id': fields.String,
@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource):
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
    }

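To show what the new retriever_resources list looks like once marshalled, a hypothetical single entry follows; the field names come straight from retriever_resource_fields above, while every value here is invented for illustration:

    example_retriever_resource = {
        'id': '9f8c2c54-0000-0000-0000-000000000000',
        'message_id': '1f2d3c4b-0000-0000-0000-000000000000',
        'position': 1,
        'dataset_id': 'a1b2c3d4-0000-0000-0000-000000000000',
        'dataset_name': 'Employee Handbook',
        'document_id': 'e5f6a7b8-0000-0000-0000-000000000000',
        'document_name': 'handbook.pdf',
        'data_source_type': 'upload_file',
        'segment_id': 'c9d0e1f2-0000-0000-0000-000000000000',
        'score': 0.82,
        'hit_count': 3,
        'word_count': 412,
        'segment_position': 7,
        'index_node_hash': 'example-hash',
        'content': 'Refunds are processed within 14 days...',
        'created_at': 1694000000,
    }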
+ 2 - 0
api/controllers/web/app.py

@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
    }
@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list
        }

+ 4 - 0
api/controllers/web/completion.py

@@ -31,6 +31,8 @@ class CompletionApi(WebApiResource):
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
+
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
@@ -88,6 +90,8 @@ class ChatApi(WebApiResource):
        parser.add_argument('query', type=str, required=True, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
+
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'

+ 20 - 0
api/controllers/web/message.py

@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource):
        'rating': fields.String
    }

+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
+
    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource):
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
    }

+ 5 - 0
api/core/agent/agent/multi_dataset_router_agent.py

@@ -1,3 +1,4 @@
+import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast

 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
            tool = next(iter(self.tools))
            tool = cast(DatasetRetrieverTool, tool)
            rst = tool.run(tool_input={'query': kwargs['input']})
+            # output = ''
+            # rst_json = json.loads(rst)
+            # for item in rst_json:
+            #     output += f'{item["content"]}\n'
            return AgentFinish(return_values={"output": rst}, log=rst)

        if intermediate_steps:

+ 1 - 4
api/core/callback_handler/dataset_tool_callback_handler.py

@@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
-        # kwargs={'name': 'Search'}
-        # llm_prefix='Thought:'
-        # observation_prefix='Observation: '
-        # output='53 years'
        pass

+
    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:

+ 7 - 1
api/core/callback_handler/index_tool_callback_handler.py

@@ -2,6 +2,7 @@ from typing import List

 from langchain.schema import Document

+from core.conversation_message_task import ConversationMessageTask
 from extensions.ext_database import db
 from models.dataset import DocumentSegment
@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

-    def __init__(self, dataset_id: str) -> None:
+    def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
         self.dataset_id = dataset_id
+        self.conversation_message_task = conversation_message_task

     def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler:
            )

            db.session.commit()
+
+    def return_retriever_resource_info(self, resource: List):
+        """Handle return_retriever_resource_info."""
+        self.conversation_message_task.on_dataset_query_finish(resource)

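A minimal sketch of how the retriever tool is expected to wire this handler up, mirroring the dataset_retriever_tool.py changes further down (context_list is the list of citation dicts built there):

    hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
    hit_callback.on_tool_end(documents)                        # bumps hit counts on the matched segments
    hit_callback.return_retriever_resource_info(context_list)  # hands citations to ConversationMessageTask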
+ 6 - 4
api/core/completion.py

@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 from typing import Optional, List, Union, Tuple
@@ -19,13 +20,15 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
+from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


 class Completion:
     @classmethod
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
+                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
+                 is_override: bool = False, retriever_from: str = 'dev'):
        """
        errors: ProviderTokenNotInitError
        """
@@ -96,7 +99,6 @@ class Completion:
            should_use_agent = agent_executor.should_use_agent(query)
            if should_use_agent:
                agent_execute_result = agent_executor.run(query)
-
        # run the final llm
        try:
            cls.run_final_llm(
@@ -118,7 +120,8 @@ class Completion:
            return

    @classmethod
-    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
+                      inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
@@ -150,7 +153,6 @@ class Completion:
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )
-
        return response

    @classmethod

+ 53 - 3
api/core/conversation_message_task.py

@@ -1,6 +1,6 @@
 import decimal
 import json
-from typing import Optional, Union
+from typing import Optional, Union, List

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
@@ -15,7 +15,8 @@ from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
-from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
+from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
+    MessageChain, DatasetRetrieverResource


 class ConversationMessageTask:
@@ -41,6 +42,8 @@ class ConversationMessageTask:

        self.message = None

+        self.retriever_resource = None
+
        self.model_dict = self.app_model_config.model_dict
        self.provider_name = self.model_dict.get('provider')
        self.model_name = self.model_dict.get('name')
@@ -157,7 +160,8 @@ class ConversationMessageTask:
        self.message.message_tokens = message_tokens
        self.message.message_unit_price = message_unit_price
        self.message.message_price_unit = message_price_unit
-        self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
+        self.message.answer = PromptBuilder.process_template(
+            llm_message.completion.strip()) if llm_message.completion else ''
        self.message.answer_tokens = answer_tokens
        self.message.answer_unit_price = answer_unit_price
        self.message.answer_price_unit = answer_price_unit
@@ -256,7 +260,36 @@

        db.session.add(dataset_query)

+    def on_dataset_query_finish(self, resource: List):
+        if resource and len(resource) > 0:
+            for item in resource:
+                dataset_retriever_resource = DatasetRetrieverResource(
+                    message_id=self.message.id,
+                    position=item.get('position'),
+                    dataset_id=item.get('dataset_id'),
+                    dataset_name=item.get('dataset_name'),
+                    document_id=item.get('document_id'),
+                    document_name=item.get('document_name'),
+                    data_source_type=item.get('data_source_type'),
+                    segment_id=item.get('segment_id'),
+                    score=item.get('score') if 'score' in item else None,
+                    hit_count=item.get('hit_count') if 'hit_count' else None,
+                    word_count=item.get('word_count') if 'word_count' in item else None,
+                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
+                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
+                    content=item.get('content'),
+                    retriever_from=item.get('retriever_from'),
+                    created_by=self.user.id
+                )
+                db.session.add(dataset_retriever_resource)
+                db.session.flush()
+            self.retriever_resource = resource
+
+    def message_end(self):
+        self._pub_handler.pub_message_end(self.retriever_resource)
+
    def end(self):
+        self._pub_handler.pub_message_end(self.retriever_resource)
        self._pub_handler.pub_end()

@@ -350,6 +383,23 @@ class PubHandler:
            self.pub_end()
            raise ConversationTaskStoppedException()

+    def pub_message_end(self, retriever_resource: List):
+        content = {
+            'event': 'message_end',
+            'data': {
+                'task_id': self._task_id,
+                'message_id': self._message.id,
+                'mode': self._conversation.mode,
+                'conversation_id': self._conversation.id
+            }
+        }
+        if retriever_resource:
+            content['data']['retriever_resources'] = retriever_resource
+        redis_client.publish(self._channel, json.dumps(content))
+
+        if self._is_stopped():
+            self.pub_end()
+            raise ConversationTaskStoppedException()

    def pub_end(self):
        content = {

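Based on pub_message_end above, a subscriber on the task's Redis channel would receive a message_end event shaped roughly like this (a sketch; retriever_resources is only attached when on_dataset_query_finish stored citations):

    message_end_event = {
        'event': 'message_end',
        'data': {
            'task_id': '<task id>',
            'message_id': '<message id>',
            'mode': 'chat',
            'conversation_id': '<conversation id>',
            'retriever_resources': [],  # list of citation dicts, when present
        },
    }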
+ 1 - 1
api/core/index/keyword_table_index/keyword_table_index.py

@@ -74,7 +74,7 @@ class KeywordTableIndex(BaseIndex):
            DocumentSegment.document_id == document_id
        ).all()

-        ids = [segment.id for segment in segments]
+        ids = [segment.index_node_id for segment in segments]

        keyword_table = self._get_dataset_keyword_table()
        keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

+ 19 - 0
api/core/index/vector_index/qdrant_vector_index.py

@@ -113,6 +113,25 @@ class QdrantVectorIndex(BaseVectorIndex):
            ],
        ))

+    def delete_by_ids(self, ids: list[str]) -> None:
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        for node_id in ids:
+            vector_store.del_texts(models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="metadata.doc_id",
+                        match=models.MatchValue(value=node_id),
+                    ),
+                ],
+            ))
+
    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

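A short usage sketch for the new delete_by_ids method, assuming an already constructed QdrantVectorIndex for the dataset; the ids passed in are the segments' index_node_id values, matching the keyword-table fix above:

    # vector_index is an existing QdrantVectorIndex instance for this dataset (assumption).
    node_ids = [segment.index_node_id for segment in segments]
    vector_index.delete_by_ids(node_ids)  # removes the points whose metadata.doc_id matches each id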
+ 1 - 0
api/core/model_providers/models/entity/message.py

@@ -8,6 +8,7 @@ class LLMRunResult(BaseModel):
    content: str
    prompt_tokens: int
    completion_tokens: int
+    source: list = None


 class MessageType(enum.Enum):

+ 20 - 9
api/core/orchestrator_rule_parser.py

@@ -36,8 +36,8 @@ class OrchestratorRuleParser:
        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
-                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
-            -> Optional[AgentExecutor]:
+                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
+                          return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None
@@ -74,7 +74,7 @@

            # only OpenAI chat model (include Azure) support function call, use ReACT instead
            if agent_model_instance.model_mode != ModelMode.CHAT \
-                         or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
+                    or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                    planning_strategy = PlanningStrategy.REACT
                elif planning_strategy == PlanningStrategy.ROUTER:
@@ -99,7 +99,9 @@ class OrchestratorRuleParser:
                tool_configs=tool_configs,
                conversation_message_task=conversation_message_task,
                rest_tokens=rest_tokens,
-                callbacks=[agent_callback, DifyStdOutCallbackHandler()]
+                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
+                return_resource=return_resource,
+                retriever_from=retriever_from
            )

            if len(tools) == 0:
@@ -145,8 +147,10 @@

        return None

-    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask,
-                 rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
+    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
+                 conversation_message_task: ConversationMessageTask,
+                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
+                 retriever_from: str = 'dev') -> list[BaseTool]:
        """
        Convert app agent tool configs to tools
@@ -155,6 +159,8 @@ class OrchestratorRuleParser:
        :param tool_configs: app agent tool configs
        :param conversation_message_task:
        :param callbacks:
+        :param return_resource:
+        :param retriever_from:
        :return:
        """
        tools = []
@@ -166,7 +172,7 @@

            tool = None
            if tool_type == "dataset":
-                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
+                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(agent_model_instance)
            elif tool_type == "google_search":
@@ -183,13 +189,15 @@ class OrchestratorRuleParser:
        return tools

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
-                                  rest_tokens: int) \
+                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param rest_tokens:
        :param tool_config:
        :param conversation_message_task:
+        :param return_resource:
+        :param retriever_from:
        :return:
        """
        # get dataset from dataset id
@@ -208,7 +216,10 @@ class OrchestratorRuleParser:
        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
-            callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
+            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
+            conversation_message_task=conversation_message_task,
+            return_resource=return_resource,
+            retriever_from=retriever_from
        )

        return tool

+ 1 - 1
api/core/prompt/generate_prompts/common_chat.json

@@ -10,4 +10,4 @@
  ],
  "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ",
  "stops": ["\nHuman:", "</histories>"]
-}
+}

+ 1 - 1
api/core/prompt/prompts.py

@@ -105,7 +105,7 @@ GENERATOR_QA_PROMPT = (
    'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
    'Step 4: Generate 20 questions and answers based on these key information and concepts.'
    'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
-    "Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
+    "Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
)

RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \

+ 48 - 4
api/core/tool/dataset_retriever_tool.py

@@ -1,3 +1,4 @@
+import json
 from typing import Type

 from flask import current_app
@@ -5,13 +6,14 @@ from langchain.tools import BaseTool
 from pydantic import Field, BaseModel

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.conversation_message_task import ConversationMessageTask
 from core.embedding.cached_embedding import CacheEmbedding
 from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, Document


 class DatasetRetrieverToolInput(BaseModel):
@@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool):
    tenant_id: str
    dataset_id: str
    k: int = 3
+    conversation_message_task: ConversationMessageTask
+    return_resource: str
+    retriever_from: str
+

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool):
            if self.k > 0:
                documents = vector_index.search(
                    query,
-                    search_type='similarity',
+                    search_type='similarity_score_threshold',
                    search_kwargs={
                        'k': self.k
                    }
@@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool):
            else:
                documents = []

-            hit_callback = DatasetIndexToolCallbackHandler(dataset.id)
+            hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
            hit_callback.on_tool_end(documents)
+            document_score_list = {}
+            if dataset.indexing_technique != "economy":
+                for item in documents:
+                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
@@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool):
                                                                                           float('inf')))
                for segment in sorted_segments:
                    if segment.answer:
-                        document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}')
+                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)
+                if self.return_resource:
+                    context_list = []
+                    resource_number = 1
+                    for segment in sorted_segments:
+                        context = {}
+                        document = Document.query.filter(Document.id == segment.document_id,
+                                                         Document.enabled == True,
+                                                         Document.archived == False,
+                                                         ).first()
+                        if dataset and document:
+                            source = {
+                                'position': resource_number,
+                                'dataset_id': dataset.id,
+                                'dataset_name': dataset.name,
+                                'document_id': document.id,
+                                'document_name': document.name,
+                                'data_source_type': document.data_source_type,
+                                'segment_id': segment.id,
+                                'retriever_from': self.retriever_from
+                            }
+                            if dataset.indexing_technique != "economy":
+                                source['score'] = document_score_list.get(segment.index_node_id)
+                            if self.retriever_from == 'dev':
+                                source['hit_count'] = segment.hit_count
+                                source['word_count'] = segment.word_count
+                                source['segment_position'] = segment.position
+                                source['index_node_hash'] = segment.index_node_hash
+                            if segment.answer:
+                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
+                            else:
+                                source['content'] = segment.content
+                            context_list.append(source)
+                        resource_number += 1
+                    hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))

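To summarize the data flow this file introduces: the string handed back to the agent keeps its previous shape, while the citations travel out-of-band through the callback. A compact sketch using the names above:

    # What the agent/LLM receives from the tool:
    tool_output = "\n".join(document_context_list)

    # What becomes citations when return_resource is enabled:
    hit_callback.return_retriever_resource_info(context_list)
    # -> ConversationMessageTask.on_dataset_query_finish(context_list)
    # -> DatasetRetrieverResource rows + 'retriever_resources' on the message_end event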
+ 54 - 0
api/migrations/versions/6dcb43972bdc_add_dataset_retriever_resource.py

@@ -0,0 +1,54 @@
+"""add_dataset_retriever_resource
+
+Revision ID: 6dcb43972bdc
+Revises: 4bcffcd64aa4
+Create Date: 2023-09-06 16:51:27.385844
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '6dcb43972bdc'
+down_revision = '4bcffcd64aa4'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset_retriever_resources',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('message_id', postgresql.UUID(), nullable=False),
+    sa.Column('position', sa.Integer(), nullable=False),
+    sa.Column('dataset_id', postgresql.UUID(), nullable=False),
+    sa.Column('dataset_name', sa.Text(), nullable=False),
+    sa.Column('document_id', postgresql.UUID(), nullable=False),
+    sa.Column('document_name', sa.Text(), nullable=False),
+    sa.Column('data_source_type', sa.Text(), nullable=False),
+    sa.Column('segment_id', postgresql.UUID(), nullable=False),
+    sa.Column('score', sa.Float(), nullable=True),
+    sa.Column('content', sa.Text(), nullable=False),
+    sa.Column('hit_count', sa.Integer(), nullable=True),
+    sa.Column('word_count', sa.Integer(), nullable=True),
+    sa.Column('segment_position', sa.Integer(), nullable=True),
+    sa.Column('index_node_hash', sa.Text(), nullable=True),
+    sa.Column('retriever_from', sa.Text(), nullable=False),
+    sa.Column('created_by', postgresql.UUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey')
+    )
+    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+        batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+        batch_op.drop_index('dataset_retriever_resource_message_id_idx')
+
+    op.drop_table('dataset_retriever_resources')
+    # ### end Alembic commands ###
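A minimal sketch (not part of this commit) of applying the new table programmatically, assuming the API app is wired up with Flask-Migrate as in the rest of the migrations directory:

# Runs the Alembic upgrade up to this revision inside an app context.
from flask_migrate import upgrade

def apply_dataset_retriever_resource_migration(app):
    with app.app_context():
        upgrade(revision='6dcb43972bdc')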

+ 32 - 0
api/migrations/versions/77e83833755c_add_app_config_retriever_resource.py

@@ -0,0 +1,32 @@
+"""add_app_config_retriever_resource
+
+Revision ID: 77e83833755c
+Revises: 6dcb43972bdc
+Create Date: 2023-09-06 17:26:40.311927
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '77e83833755c'
+down_revision = '6dcb43972bdc'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('retriever_resource')
+
+    # ### end Alembic commands ###
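The new column mirrors the other per-feature toggles on app_model_configs: it stores a small JSON object as text. A quick, hypothetical round-trip of the stored value:

import json

# Hypothetical value written to the app_model_configs.retriever_resource text column.
stored = json.dumps({"enabled": True})
assert json.loads(stored) == {"enabled": True}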

+ 44 - 1
api/models/model.py

@@ -1,4 +1,5 @@
 import json
+from json import JSONDecodeError
 
 from flask import current_app, request
 from flask_login import UserMixin
@@ -90,6 +91,7 @@ class AppModelConfig(db.Model):
     pre_prompt = db.Column(db.Text)
     agent_mode = db.Column(db.Text)
     sensitive_word_avoidance = db.Column(db.Text)
+    retriever_resource = db.Column(db.Text)
 
     @property
     def app(self):
@@ -114,6 +116,11 @@ class AppModelConfig(db.Model):
         return json.loads(self.speech_to_text) if self.speech_to_text \
             else {"enabled": False}
 
+    @property
+    def retriever_resource_dict(self) -> dict:
+        return json.loads(self.retriever_resource) if self.retriever_resource \
+            else {"enabled": False}
+
     @property
     def more_like_this_dict(self) -> dict:
         return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
@@ -140,6 +147,7 @@ class AppModelConfig(db.Model):
             "suggested_questions": self.suggested_questions_list,
             "suggested_questions": self.suggested_questions_list,
             "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
             "suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
             "speech_to_text": self.speech_to_text_dict,
             "speech_to_text": self.speech_to_text_dict,
+            "retriever_resource": self.retriever_resource,
             "more_like_this": self.more_like_this_dict,
             "more_like_this": self.more_like_this_dict,
             "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
             "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
             "model": self.model_dict,
             "model": self.model_dict,
@@ -164,7 +172,8 @@ class AppModelConfig(db.Model):
         self.user_input_form = json.dumps(model_config['user_input_form'])
         self.pre_prompt = model_config['pre_prompt']
         self.agent_mode = json.dumps(model_config['agent_mode'])
-
+        self.retriever_resource = json.dumps(model_config['retriever_resource']) \
+            if model_config.get('retriever_resource') else None
         return self
 
     def copy(self):
@@ -318,6 +327,7 @@ class Conversation(db.Model):
             model_config['suggested_questions'] = app_model_config.suggested_questions_list
             model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
             model_config['speech_to_text'] = app_model_config.speech_to_text_dict
+            model_config['retriever_resource'] = app_model_config.retriever_resource_dict
             model_config['more_like_this'] = app_model_config.more_like_this_dict
             model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
             model_config['user_input_form'] = app_model_config.user_input_form_list
@@ -476,6 +486,11 @@ class Message(db.Model):
         return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
             .order_by(MessageAgentThought.position.asc()).all()
 
+    @property
+    def retriever_resources(self):
+        return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
+            .order_by(DatasetRetrieverResource.position.asc()).all()
+
 
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
@@ -719,3 +734,31 @@ class MessageAgentThought(db.Model):
     created_by_role = db.Column(db.String, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+
+
+class DatasetRetrieverResource(db.Model):
+    __tablename__ = 'dataset_retriever_resources'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'),
+        db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
+    )
+
+    id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(UUID, nullable=False)
+    position = db.Column(db.Integer, nullable=False)
+    dataset_id = db.Column(UUID, nullable=False)
+    dataset_name = db.Column(db.Text, nullable=False)
+    document_id = db.Column(UUID, nullable=False)
+    document_name = db.Column(db.Text, nullable=False)
+    data_source_type = db.Column(db.Text, nullable=False)
+    segment_id = db.Column(UUID, nullable=False)
+    score = db.Column(db.Float, nullable=True)
+    content = db.Column(db.Text, nullable=False)
+    hit_count = db.Column(db.Integer, nullable=True)
+    word_count = db.Column(db.Integer, nullable=True)
+    segment_position = db.Column(db.Integer, nullable=True)
+    index_node_hash = db.Column(db.Text, nullable=True)
+    retriever_from = db.Column(db.Text, nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+
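A rough usage sketch (not part of the commit) of the DatasetRetrieverResource rows exposed through the Message.retriever_resources property added above; `message` is assumed to be a Message row already loaded inside a Flask app context:

def summarize_citations(message):
    # Rows come back ordered by position (see the property on Message above).
    return [
        {
            'position': resource.position,
            'dataset': resource.dataset_name,
            'document': resource.document_name,
            'score': resource.score,  # may be None, e.g. for "economy" indexing
        }
        for resource in message.retriever_resources
    ]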

+ 16 - 0
api/services/app_model_config_service.py

@@ -130,6 +130,21 @@ class AppModelConfigService:
         if not isinstance(config["speech_to_text"]["enabled"], bool):
         if not isinstance(config["speech_to_text"]["enabled"], bool):
             raise ValueError("enabled in speech_to_text must be of boolean type")
             raise ValueError("enabled in speech_to_text must be of boolean type")
 
 
+        # return retriever resource
+        if 'retriever_resource' not in config or not config["retriever_resource"]:
+            config["retriever_resource"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["retriever_resource"], dict):
+            raise ValueError("retriever_resource must be of dict type")
+
+        if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]:
+            config["retriever_resource"]["enabled"] = False
+
+        if not isinstance(config["retriever_resource"]["enabled"], bool):
+            raise ValueError("enabled in speech_to_text must be of boolean type")
+
         # more_like_this
         # more_like_this
         if 'more_like_this' not in config or not config["more_like_this"]:
         if 'more_like_this' not in config or not config["more_like_this"]:
             config["more_like_this"] = {
             config["more_like_this"] = {
@@ -327,6 +342,7 @@ class AppModelConfigService:
             "suggested_questions": config["suggested_questions"],
             "suggested_questions": config["suggested_questions"],
             "suggested_questions_after_answer": config["suggested_questions_after_answer"],
             "suggested_questions_after_answer": config["suggested_questions_after_answer"],
             "speech_to_text": config["speech_to_text"],
             "speech_to_text": config["speech_to_text"],
+            "retriever_resource": config["retriever_resource"],
             "more_like_this": config["more_like_this"],
             "more_like_this": config["more_like_this"],
             "sensitive_word_avoidance": config["sensitive_word_avoidance"],
             "sensitive_word_avoidance": config["sensitive_word_avoidance"],
             "model": {
             "model": {

+ 27 - 4
api/services/completion_service.py

@@ -11,7 +11,8 @@ from sqlalchemy import and_
 
 
 from core.completion import Completion
 from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
-from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
+    LLMRateLimitError, \
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -95,6 +96,7 @@ class CompletionService:
 
 
                 app_model_config_model = app_model_config.model_dict
                 app_model_config_model['completion_params'] = completion_params
+                app_model_config.retriever_resource = json.dumps({'enabled': True})
 
                 app_model_config = app_model_config.copy()
                 app_model_config.model = json.dumps(app_model_config_model)
@@ -145,7 +147,8 @@ class CompletionService:
             'user': user,
             'conversation': conversation,
             'streaming': streaming,
-            'is_model_config_override': is_model_config_override
+            'is_model_config_override': is_model_config_override,
+            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
         })
 
         generate_worker_thread.start()
@@ -169,7 +172,8 @@ class CompletionService:
     @classmethod
     def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
                         query: str, inputs: dict, user: Union[Account, EndUser],
-                        conversation: Conversation, streaming: bool, is_model_config_override: bool):
+                        conversation: Conversation, streaming: bool, is_model_config_override: bool,
+                        retriever_from: str = 'dev'):
         with flask_app.app_context():
             try:
                 if conversation:
@@ -188,6 +192,7 @@ class CompletionService:
                     conversation=conversation,
                     streaming=streaming,
                     is_override=is_model_config_override,
+                    retriever_from=retriever_from
                 )
             except ConversationTaskStoppedException:
                 pass
@@ -400,7 +405,11 @@ class CompletionService:
                             elif event == 'chain':
                                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                             elif event == 'agent_thought':
-                                yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
+                                yield "data: " + json.dumps(
+                                    cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
+                            elif event == 'message_end':
+                                yield "data: " + json.dumps(
+                                    cls.get_message_end_data(result.get('data'))) + "\n\n"
                             elif event == 'ping':
                                 yield "event: ping\n\n"
                             else:
@@ -432,6 +441,20 @@ class CompletionService:
 
 
         return response_data
 
+    @classmethod
+    def get_message_end_data(cls, data: dict):
+        response_data = {
+            'event': 'message_end',
+            'task_id': data.get('task_id'),
+            'id': data.get('message_id')
+        }
+        if 'retriever_resources' in data:
+            response_data['retriever_resources'] = data.get('retriever_resources')
+        if data.get('mode') == 'chat':
+            response_data['conversation_id'] = data.get('conversation_id')
+
+        return response_data
+
     @classmethod
     def get_chain_response_data(cls, data: dict):
         response_data = {