Преглед изворни кода

Feat/dataset service api (#1245)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Jyong пре 1 година
родитељ
комит
46154c6705
43 измењених фајлова са 1632 додато и 902 уклоњено
  1. 1 0
      api/controllers/console/apikey.py
  2. 2 103
      api/controllers/console/app/app.py
  3. 4 153
      api/controllers/console/app/conversation.py
  4. 1 38
      api/controllers/console/app/message.py
  5. 1 16
      api/controllers/console/app/site.py
  6. 1 53
      api/controllers/console/datasets/data_source.py
  7. 96 68
      api/controllers/console/datasets/datasets.py
  8. 4 97
      api/controllers/console/datasets/datasets_document.py
  9. 2 31
      api/controllers/console/datasets/datasets_segments.py
  10. 12 86
      api/controllers/console/datasets/file.py
  11. 2 40
      api/controllers/console/datasets/hit_testing.py
  12. 2 16
      api/controllers/console/explore/conversation.py
  13. 1 22
      api/controllers/console/explore/installed_app.py
  14. 1 39
      api/controllers/console/explore/message.py
  15. 4 18
      api/controllers/console/universal_chat/conversation.py
  16. 1 1
      api/controllers/service_api/__init__.py
  17. 3 17
      api/controllers/service_api/app/conversation.py
  18. 84 0
      api/controllers/service_api/dataset/dataset.py
  19. 322 68
      api/controllers/service_api/dataset/document.py
  20. 61 8
      api/controllers/service_api/dataset/error.py
  21. 59 0
      api/controllers/service_api/dataset/segment.py
  22. 22 7
      api/controllers/service_api/wraps.py
  23. 2 16
      api/controllers/web/conversation.py
  24. 1 0
      api/core/data_loader/loader/notion.py
  25. 17 0
      api/core/index/keyword_table_index/keyword_table_index.py
  26. 0 0
      api/fields/__init__.py
  27. 138 0
      api/fields/app_fields.py
  28. 182 0
      api/fields/conversation_fields.py
  29. 65 0
      api/fields/data_source_fields.py
  30. 43 0
      api/fields/dataset_fields.py
  31. 76 0
      api/fields/document_fields.py
  32. 18 0
      api/fields/file_fields.py
  33. 41 0
      api/fields/hit_testing_fields.py
  34. 25 0
      api/fields/installed_app_fields.py
  35. 43 0
      api/fields/message_fields.py
  36. 32 0
      api/fields/segment_fields.py
  37. 36 0
      api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py
  38. 3 2
      api/models/model.py
  39. 66 2
      api/services/dataset_service.py
  40. 1 1
      api/services/errors/__init__.py
  41. 8 0
      api/services/errors/file.py
  42. 123 0
      api/services/file_service.py
  43. 26 0
      api/services/vector_service.py

+ 1 - 0
api/controllers/console/apikey.py

@@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
         key = ApiToken.generate_api_key(self.token_prefix, 24)
         api_token = ApiToken()
         setattr(api_token, self.resource_id_field, resource_id)
+        api_token.tenant_id = current_user.current_tenant_id
         api_token.token = key
         api_token.type = self.resource_type
         db.session.add(api_token)

+ 2 - 103
api/controllers/console/app/app.py

@@ -19,41 +19,13 @@ from core.model_providers.model_factory import ModelFactory
 from core.model_providers.model_provider_factory import ModelProviderFactory
 from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
+from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
+    app_detail_fields_with_site
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService
 
-model_config_fields = {
-    'opening_statement': fields.String,
-    'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
-    'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
-    'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
-    'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
-    'more_like_this': fields.Raw(attribute='more_like_this_dict'),
-    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
-    'model': fields.Raw(attribute='model_dict'),
-    'user_input_form': fields.Raw(attribute='user_input_form_list'),
-    'dataset_query_variable': fields.String,
-    'pre_prompt': fields.String,
-    'agent_mode': fields.Raw(attribute='agent_mode_dict'),
-}
-
-app_detail_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'enable_site': fields.Boolean,
-    'enable_api': fields.Boolean,
-    'api_rpm': fields.Integer,
-    'api_rph': fields.Integer,
-    'is_demo': fields.Boolean,
-    'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
-    'created_at': TimestampField
-}
-
 
 def _get_app(app_id, tenant_id):
     app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id):
 
 
 class AppListApi(Resource):
-    prompt_config_fields = {
-        'prompt_template': fields.String,
-    }
-
-    model_config_partial_fields = {
-        'model': fields.Raw(attribute='model_dict'),
-        'pre_prompt': fields.String,
-    }
-
-    app_partial_fields = {
-        'id': fields.String,
-        'name': fields.String,
-        'mode': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'enable_site': fields.Boolean,
-        'enable_api': fields.Boolean,
-        'is_demo': fields.Boolean,
-        'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
-        'created_at': TimestampField
-    }
-
-    app_pagination_fields = {
-        'page': fields.Integer,
-        'limit': fields.Integer(attribute='per_page'),
-        'total': fields.Integer,
-        'has_more': fields.Boolean(attribute='has_next'),
-        'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
-    }
 
     @setup_required
     @login_required
@@ -238,18 +181,6 @@ class AppListApi(Resource):
 
 
 class AppTemplateApi(Resource):
-    template_fields = {
-        'name': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'description': fields.String,
-        'mode': fields.String,
-        'model_config': fields.Nested(model_config_fields),
-    }
-
-    template_list_fields = {
-        'data': fields.List(fields.Nested(template_fields)),
-    }
 
     @setup_required
     @login_required
@@ -268,38 +199,6 @@ class AppTemplateApi(Resource):
 
 
 class AppApi(Resource):
-    site_fields = {
-        'access_token': fields.String(attribute='code'),
-        'code': fields.String,
-        'title': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'description': fields.String,
-        'default_language': fields.String,
-        'customize_domain': fields.String,
-        'copyright': fields.String,
-        'privacy_policy': fields.String,
-        'customize_token_strategy': fields.String,
-        'prompt_public': fields.Boolean,
-        'app_base_url': fields.String,
-    }
-
-    app_detail_fields_with_site = {
-        'id': fields.String,
-        'name': fields.String,
-        'mode': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'enable_site': fields.Boolean,
-        'enable_api': fields.Boolean,
-        'api_rpm': fields.Integer,
-        'api_rph': fields.Integer,
-        'is_demo': fields.Boolean,
-        'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
-        'site': fields.Nested(site_fields),
-        'api_base_url': fields.String,
-        'created_at': TimestampField
-    }
 
     @setup_required
     @login_required

+ 4 - 153
api/controllers/console/app/conversation.py

@@ -13,107 +13,14 @@ from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
+    conversation_message_detail_fields, conversation_with_summary_pagination_fields
 from libs.helper import TimestampField, datetime_string, uuid_value
 from extensions.ext_database import db
 from models.model import Message, MessageAnnotation, Conversation
 
-account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
-
-feedback_fields = {
-    'rating': fields.String,
-    'content': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account': fields.Nested(account_fields, allow_null=True),
-}
-
-annotation_fields = {
-    'content': fields.String,
-    'account': fields.Nested(account_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
-message_detail_fields = {
-    'id': fields.String,
-    'conversation_id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'message': fields.Raw,
-    'message_tokens': fields.Integer,
-    'answer': fields.String,
-    'answer_tokens': fields.Integer,
-    'provider_response_latency': fields.Float,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'feedbacks': fields.List(fields.Nested(feedback_fields)),
-    'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
-feedback_stat_fields = {
-    'like': fields.Integer,
-    'dislike': fields.Integer
-}
-
-model_config_fields = {
-    'opening_statement': fields.String,
-    'suggested_questions': fields.Raw,
-    'model': fields.Raw,
-    'user_input_form': fields.Raw,
-    'pre_prompt': fields.String,
-    'agent_mode': fields.Raw,
-}
-
 
 class CompletionConversationApi(Resource):
-    class MessageTextField(fields.Raw):
-        def format(self, value):
-            return value[0]['text'] if value else ''
-
-    simple_configs_fields = {
-        'prompt_template': fields.String,
-    }
-
-    simple_model_config_fields = {
-        'model': fields.Raw(attribute='model_dict'),
-        'pre_prompt': fields.String,
-    }
-
-    simple_message_detail_fields = {
-        'inputs': fields.Raw,
-        'query': fields.String,
-        'message': MessageTextField,
-        'answer': fields.String,
-    }
-
-    conversation_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_end_user_session_id': fields.String(),
-        'from_account_id': fields.String,
-        'read_at': TimestampField,
-        'created_at': TimestampField,
-        'annotation': fields.Nested(annotation_fields, allow_null=True),
-        'model_config': fields.Nested(simple_model_config_fields),
-        'user_feedback_stats': fields.Nested(feedback_stat_fields),
-        'admin_feedback_stats': fields.Nested(feedback_stat_fields),
-        'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
-    }
-
-    conversation_pagination_fields = {
-        'page': fields.Integer,
-        'limit': fields.Integer(attribute='per_page'),
-        'total': fields.Integer,
-        'has_more': fields.Boolean(attribute='has_next'),
-        'data': fields.List(fields.Nested(conversation_fields), attribute='items')
-    }
 
     @setup_required
     @login_required
@@ -191,21 +98,11 @@ class CompletionConversationApi(Resource):
 
 
 class CompletionConversationDetailApi(Resource):
-    conversation_detail_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_account_id': fields.String,
-        'created_at': TimestampField,
-        'model_config': fields.Nested(model_config_fields),
-        'message': fields.Nested(message_detail_fields, attribute='first_message'),
-    }
 
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(conversation_detail_fields)
+    @marshal_with(conversation_message_detail_fields)
     def get(self, app_id, conversation_id):
         app_id = str(app_id)
         conversation_id = str(conversation_id)
@@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource):
 
 
 class ChatConversationApi(Resource):
-    simple_configs_fields = {
-        'prompt_template': fields.String,
-    }
-
-    simple_model_config_fields = {
-        'model': fields.Raw(attribute='model_dict'),
-        'pre_prompt': fields.String,
-    }
-
-    conversation_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_end_user_session_id': fields.String,
-        'from_account_id': fields.String,
-        'summary': fields.String(attribute='summary_or_query'),
-        'read_at': TimestampField,
-        'created_at': TimestampField,
-        'annotated': fields.Boolean,
-        'model_config': fields.Nested(simple_model_config_fields),
-        'message_count': fields.Integer,
-        'user_feedback_stats': fields.Nested(feedback_stat_fields),
-        'admin_feedback_stats': fields.Nested(feedback_stat_fields)
-    }
-
-    conversation_pagination_fields = {
-        'page': fields.Integer,
-        'limit': fields.Integer(attribute='per_page'),
-        'total': fields.Integer,
-        'has_more': fields.Boolean(attribute='has_next'),
-        'data': fields.List(fields.Nested(conversation_fields), attribute='items')
-    }
 
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(conversation_pagination_fields)
+    @marshal_with(conversation_with_summary_pagination_fields)
     def get(self, app_id):
         app_id = str(app_id)
 
@@ -356,19 +220,6 @@ class ChatConversationApi(Resource):
 
 
 class ChatConversationDetailApi(Resource):
-    conversation_detail_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_account_id': fields.String,
-        'created_at': TimestampField,
-        'annotated': fields.Boolean,
-        'model_config': fields.Nested(model_config_fields),
-        'message_count': fields.Integer,
-        'user_feedback_stats': fields.Nested(feedback_stat_fields),
-        'admin_feedback_stats': fields.Nested(feedback_stat_fields)
-    }
 
     @setup_required
     @login_required

+ 1 - 38
api/controllers/console/app/message.py

@@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required
 from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from core.login.login import login_required
+from fields.conversation_fields import message_detail_fields
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
@@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError
 from services.message_service import MessageService
 
-account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
-
-feedback_fields = {
-    'rating': fields.String,
-    'content': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account': fields.Nested(account_fields, allow_null=True),
-}
-
-annotation_fields = {
-    'content': fields.String,
-    'account': fields.Nested(account_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
-message_detail_fields = {
-    'id': fields.String,
-    'conversation_id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'message': fields.Raw,
-    'message_tokens': fields.Integer,
-    'answer': fields.String,
-    'answer_tokens': fields.Integer,
-    'provider_response_latency': fields.Float,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'feedbacks': fields.List(fields.Nested(feedback_fields)),
-    'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
 
 class ChatMessageListApi(Resource):
     message_infinite_scroll_pagination_fields = {

+ 1 - 16
api/controllers/console/app/site.py

@@ -8,26 +8,11 @@ from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from fields.app_fields import app_site_fields
 from libs.helper import supported_language
 from extensions.ext_database import db
 from models.model import Site
 
-app_site_fields = {
-    'app_id': fields.String,
-    'access_token': fields.String(attribute='code'),
-    'code': fields.String,
-    'title': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'description': fields.String,
-    'default_language': fields.String,
-    'customize_domain': fields.String,
-    'copyright': fields.String,
-    'privacy_policy': fields.String,
-    'customize_token_strategy': fields.String,
-    'prompt_public': fields.Boolean
-}
-
 
 def parse_app_site_args():
     parser = reqparse.RequestParser()

+ 1 - 53
api/controllers/console/datasets/data_source.py

@@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required
 from core.data_loader.loader.notion import NotionLoader
 from core.indexing_runner import IndexingRunner
 from extensions.ext_database import db
+from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
 from libs.helper import TimestampField
 from models.dataset import Document
 from models.source import DataSourceBinding
@@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30)
 
 
 class DataSourceApi(Resource):
-    integrate_icon_fields = {
-        'type': fields.String,
-        'url': fields.String,
-        'emoji': fields.String
-    }
-    integrate_page_fields = {
-        'page_name': fields.String,
-        'page_id': fields.String,
-        'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
-        'parent_id': fields.String,
-        'type': fields.String
-    }
-    integrate_workspace_fields = {
-        'workspace_name': fields.String,
-        'workspace_id': fields.String,
-        'workspace_icon': fields.String,
-        'pages': fields.List(fields.Nested(integrate_page_fields)),
-        'total': fields.Integer
-    }
-    integrate_fields = {
-        'id': fields.String,
-        'provider': fields.String,
-        'created_at': TimestampField,
-        'is_bound': fields.Boolean,
-        'disabled': fields.Boolean,
-        'link': fields.String,
-        'source_info': fields.Nested(integrate_workspace_fields)
-    }
-    integrate_list_fields = {
-        'data': fields.List(fields.Nested(integrate_fields)),
-    }
 
     @setup_required
     @login_required
@@ -131,28 +101,6 @@ class DataSourceApi(Resource):
 
 
 class DataSourceNotionListApi(Resource):
-    integrate_icon_fields = {
-        'type': fields.String,
-        'url': fields.String,
-        'emoji': fields.String
-    }
-    integrate_page_fields = {
-        'page_name': fields.String,
-        'page_id': fields.String,
-        'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
-        'is_bound': fields.Boolean,
-        'parent_id': fields.String,
-        'type': fields.String
-    }
-    integrate_workspace_fields = {
-        'workspace_name': fields.String,
-        'workspace_id': fields.String,
-        'workspace_icon': fields.String,
-        'pages': fields.List(fields.Nested(integrate_page_fields))
-    }
-    integrate_notion_info_list_fields = {
-        'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
-    }
 
     @setup_required
     @login_required

+ 96 - 68
api/controllers/console/datasets/datasets.py

@@ -1,6 +1,9 @@
 # -*- coding:utf-8 -*-
-from flask import request
+import flask_restful
+from flask import request, current_app
 from flask_login import current_user
+
+from controllers.console.apikey import api_key_list, api_key_fields
 from core.login.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
@@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
 from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
-from core.model_providers.model_factory import ModelFactory
 from core.model_providers.models.entity.model_params import ModelType
-from libs.helper import TimestampField
+from fields.app_fields import related_app_list
+from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
+from fields.document_fields import document_status_fields
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
-from models.model import UploadFile
+from models.model import UploadFile, ApiToken
 from services.dataset_service import DatasetService, DocumentService
 from services.provider_service import ProviderService
 
-dataset_detail_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'provider': fields.String,
-    'permission': fields.String,
-    'data_source_type': fields.String,
-    'indexing_technique': fields.String,
-    'app_count': fields.Integer,
-    'document_count': fields.Integer,
-    'word_count': fields.Integer,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'updated_by': fields.String,
-    'updated_at': TimestampField,
-    'embedding_model': fields.String,
-    'embedding_model_provider': fields.String,
-    'embedding_available': fields.Boolean
-}
-
-dataset_query_detail_fields = {
-    "id": fields.String,
-    "content": fields.String,
-    "source": fields.String,
-    "source_app_id": fields.String,
-    "created_by_role": fields.String,
-    "created_by": fields.String,
-    "created_at": TimestampField
-}
-
 
 def _validate_name(name):
     if not name or len(name) < 1 or len(name) > 40:
@@ -82,7 +56,8 @@ class DatasetListApi(Resource):
 
         # check embedding setting
         provider_service = ProviderService()
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
+        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
+                                                                 ModelType.EMBEDDINGS.value)
         # if len(valid_model_list) == 0:
         #     raise ProviderNotInitializeError(
         #         f"No Embedding Model available. Please configure a valid provider "
@@ -157,7 +132,8 @@ class DatasetApi(Resource):
         # check embedding setting
         provider_service = ProviderService()
         # get valid model list
-        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
+        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
+                                                                 ModelType.EMBEDDINGS.value)
         model_names = []
         for valid_model in valid_model_list:
             model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
@@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
         parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
-        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         # validate args
         DocumentService.estimate_args_validate(args)
@@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource):
 
 
 class DatasetRelatedAppListApi(Resource):
-    app_detail_kernel_fields = {
-        'id': fields.String,
-        'name': fields.String,
-        'mode': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-    }
-
-    related_app_list = {
-        'data': fields.List(fields.Nested(app_detail_kernel_fields)),
-        'total': fields.Integer,
-    }
 
     @setup_required
     @login_required
@@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
 
 
 class DatasetIndexingStatusApi(Resource):
-    document_status_fields = {
-        'id': fields.String,
-        'indexing_status': fields.String,
-        'processing_started_at': TimestampField,
-        'parsing_completed_at': TimestampField,
-        'cleaning_completed_at': TimestampField,
-        'splitting_completed_at': TimestampField,
-        'completed_at': TimestampField,
-        'paused_at': TimestampField,
-        'error': fields.String,
-        'stopped_at': TimestampField,
-        'completed_segments': fields.Integer,
-        'total_segments': fields.Integer,
-    }
-
-    document_status_fields_list = {
-        'data': fields.List(fields.Nested(document_status_fields))
-    }
 
     @setup_required
     @login_required
@@ -400,16 +347,97 @@ class DatasetIndexingStatusApi(Resource):
                                                           DocumentSegment.status != 're_segment').count()
             document.completed_segments = completed_segments
             document.total_segments = total_segments
-            documents_status.append(marshal(document, self.document_status_fields))
+            documents_status.append(marshal(document, document_status_fields))
         data = {
             'data': documents_status
         }
         return data
 
 
+class DatasetApiKeyApi(Resource):
+    max_keys = 10
+    token_prefix = 'dataset-'
+    resource_type = 'dataset'
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(api_key_list)
+    def get(self):
+        keys = db.session.query(ApiToken). \
+            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
+            all()
+        return {"items": keys}
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(api_key_fields)
+    def post(self):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        current_key_count = db.session.query(ApiToken). \
+            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
+            count()
+
+        if current_key_count >= self.max_keys:
+            flask_restful.abort(
+                400,
+                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
+                code='max_keys_exceeded'
+            )
+
+        key = ApiToken.generate_api_key(self.token_prefix, 24)
+        api_token = ApiToken()
+        api_token.tenant_id = current_user.current_tenant_id
+        api_token.token = key
+        api_token.type = self.resource_type
+        db.session.add(api_token)
+        db.session.commit()
+        return api_token, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, api_key_id):
+        api_key_id = str(api_key_id)
+
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        key = db.session.query(ApiToken). \
+            filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
+                   ApiToken.id == api_key_id). \
+            first()
+
+        if key is None:
+            flask_restful.abort(404, message='API key not found')
+
+        db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
+        db.session.commit()
+
+        return {'result': 'success'}, 204
+
+
+class DatasetApiBaseUrlApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        return {
+            'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
+                             else request.host_url.rstrip('/')) + '/v1'
+        }
+
+
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
 api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
 api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
 api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
+api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
+api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')

+ 4 - 97
api/controllers/console/datasets/datasets_document.py

@@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
     LLMBadRequestError
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
+from fields.document_fields import document_with_segments_fields, document_fields, \
+    dataset_and_document_fields, document_status_fields
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DatasetProcessRule, Dataset
@@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService
 from tasks.add_document_to_index_task import add_document_to_index_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task
 
-dataset_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'description': fields.String,
-    'permission': fields.String,
-    'data_source_type': fields.String,
-    'indexing_technique': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-}
-
-document_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'data_source_type': fields.String,
-    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
-    'dataset_process_rule_id': fields.String,
-    'name': fields.String,
-    'created_from': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'tokens': fields.Integer,
-    'indexing_status': fields.String,
-    'error': fields.String,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'archived': fields.Boolean,
-    'display_status': fields.String,
-    'word_count': fields.Integer,
-    'hit_count': fields.Integer,
-    'doc_form': fields.String,
-}
-
-document_with_segments_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'data_source_type': fields.String,
-    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
-    'dataset_process_rule_id': fields.String,
-    'name': fields.String,
-    'created_from': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'tokens': fields.Integer,
-    'indexing_status': fields.String,
-    'error': fields.String,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'archived': fields.Boolean,
-    'display_status': fields.String,
-    'word_count': fields.Integer,
-    'hit_count': fields.Integer,
-    'completed_segments': fields.Integer,
-    'total_segments': fields.Integer
-}
-
 
 class DocumentResource(Resource):
     def get_document(self, dataset_id: str, document_id: str) -> Document:
@@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource):
 
 
 class DatasetInitApi(Resource):
-    dataset_and_document_fields = {
-        'dataset': fields.Nested(dataset_fields),
-        'documents': fields.List(fields.Nested(document_fields)),
-        'batch': fields.String
-    }
 
     @setup_required
     @login_required
@@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
 
 
 class DocumentBatchIndexingStatusApi(DocumentResource):
-    document_status_fields = {
-        'id': fields.String,
-        'indexing_status': fields.String,
-        'processing_started_at': TimestampField,
-        'parsing_completed_at': TimestampField,
-        'cleaning_completed_at': TimestampField,
-        'splitting_completed_at': TimestampField,
-        'completed_at': TimestampField,
-        'paused_at': TimestampField,
-        'error': fields.String,
-        'stopped_at': TimestampField,
-        'completed_segments': fields.Integer,
-        'total_segments': fields.Integer,
-    }
-
-    document_status_fields_list = {
-        'data': fields.List(fields.Nested(document_status_fields))
-    }
 
     @setup_required
     @login_required
@@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
             document.total_segments = total_segments
             if document.is_paused:
                 document.indexing_status = 'paused'
-            documents_status.append(marshal(document, self.document_status_fields))
+            documents_status.append(marshal(document, document_status_fields))
         data = {
             'data': documents_status
         }
@@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
 
 
 class DocumentIndexingStatusApi(DocumentResource):
-    document_status_fields = {
-        'id': fields.String,
-        'indexing_status': fields.String,
-        'processing_started_at': TimestampField,
-        'parsing_completed_at': TimestampField,
-        'cleaning_completed_at': TimestampField,
-        'splitting_completed_at': TimestampField,
-        'completed_at': TimestampField,
-        'paused_at': TimestampField,
-        'error': fields.String,
-        'stopped_at': TimestampField,
-        'completed_segments': fields.Integer,
-        'total_segments': fields.Integer,
-    }
 
     @setup_required
     @login_required
@@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource):
         document.total_segments = total_segments
         if document.is_paused:
             document.indexing_status = 'paused'
-        return marshal(document, self.document_status_fields)
+        return marshal(document, document_status_fields)
 
 
 class DocumentDetailApi(DocumentResource):

+ 2 - 31
api/controllers/console/datasets/datasets_segments.py

@@ -3,7 +3,7 @@ import uuid
 from datetime import datetime
 from flask import request
 from flask_login import current_user
-from flask_restful import Resource, reqparse, fields, marshal
+from flask_restful import Resource, reqparse, marshal
 from werkzeug.exceptions import NotFound, Forbidden
 
 import services
@@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory
 from core.login.login import login_required
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from fields.segment_fields import segment_fields
 from models.dataset import DocumentSegment
 
 from libs.helper import TimestampField
@@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
 import pandas as pd
 
-segment_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'document_id': fields.String,
-    'content': fields.String,
-    'answer': fields.String,
-    'word_count': fields.Integer,
-    'tokens': fields.Integer,
-    'keywords': fields.List(fields.String),
-    'index_node_id': fields.String,
-    'index_node_hash': fields.String,
-    'hit_count': fields.Integer,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'status': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'indexing_at': TimestampField,
-    'completed_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField
-}
-
-segment_list_response = {
-    'data': fields.List(fields.Nested(segment_fields)),
-    'has_more': fields.Boolean,
-    'limit': fields.Integer
-}
-
 
 class DatasetDocumentSegmentListApi(Resource):
     @setup_required

+ 12 - 86
api/controllers/console/datasets/file.py

@@ -1,28 +1,19 @@
-import datetime
-import hashlib
-import tempfile
-import chardet
-import time
-import uuid
-from pathlib import Path
-
 from cachetools import TTLCache
 from flask import request, current_app
-from flask_login import current_user
+
+import services
 from core.login.login import login_required
 from flask_restful import Resource, marshal_with, fields
-from werkzeug.exceptions import NotFound
 
 from controllers.console import api
 from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
     UnsupportedFileTypeError
+
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.file_extractor import FileExtractor
-from extensions.ext_storage import storage
-from libs.helper import TimestampField
-from extensions.ext_database import db
-from models.model import UploadFile
+from fields.file_fields import upload_config_fields, file_fields
+
+from services.file_service import FileService
 
 cache = TTLCache(maxsize=None, ttl=30)
 
@@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
 
 
 class FileApi(Resource):
-    upload_config_fields = {
-        'file_size_limit': fields.Integer,
-        'batch_count_limit': fields.Integer
-    }
 
     @setup_required
     @login_required
@@ -48,16 +35,6 @@ class FileApi(Resource):
             'batch_count_limit': batch_count_limit
         }, 200
 
-    file_fields = {
-        'id': fields.String,
-        'name': fields.String,
-        'size': fields.Integer,
-        'extension': fields.String,
-        'mime_type': fields.String,
-        'created_by': fields.String,
-        'created_at': TimestampField,
-    }
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -73,45 +50,13 @@ class FileApi(Resource):
 
         if len(request.files) > 1:
             raise TooManyFilesError()
-
-        file_content = file.read()
-        file_size = len(file_content)
-
-        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
-        if file_size > file_size_limit:
-            message = "({file_size} > {file_size_limit})"
-            raise FileTooLargeError(message)
-
-        extension = file.filename.split('.')[-1]
-        if extension.lower() not in ALLOWED_EXTENSIONS:
+        try:
+            upload_file = FileService.upload_file(file)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
-        # user uuid as file name
-        file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
-
-        # save file to storage
-        storage.save(file_key, file_content)
-
-        # save file to db
-        config = current_app.config
-        upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
-            storage_type=config['STORAGE_TYPE'],
-            key=file_key,
-            name=file.filename,
-            size=file_size,
-            extension=extension,
-            mime_type=file.mimetype,
-            created_by=current_user.id,
-            created_at=datetime.datetime.utcnow(),
-            used=False,
-            hash=hashlib.sha3_256(file_content).hexdigest()
-        )
-
-        db.session.add(upload_file)
-        db.session.commit()
-
         return upload_file, 201
 
 
@@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
     @account_initialization_required
     def get(self, file_id):
         file_id = str(file_id)
-
-        key = file_id + request.path
-        cached_response = cache.get(key)
-        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
-            return cached_response['response']
-
-        upload_file = db.session.query(UploadFile) \
-            .filter(UploadFile.id == file_id) \
-            .first()
-
-        if not upload_file:
-            raise NotFound("File not found")
-
-        # extract text from file
-        extension = upload_file.extension
-        if extension.lower() not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
-        text = FileExtractor.load(upload_file, return_text=True)
-        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
+        text = FileService.get_file_preview(file_id)
         return {'content': text}
 
 

+ 2 - 40
api/controllers/console/datasets/hit_testing.py

@@ -2,7 +2,7 @@ import logging
 
 from flask_login import current_user
 from core.login.login import login_required
-from flask_restful import Resource, reqparse, marshal, fields
+from flask_restful import Resource, reqparse, marshal
 from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
 
 import services
@@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError
-from libs.helper import TimestampField
+from fields.hit_testing_fields import hit_testing_record_fields
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
 
-document_fields = {
-    'id': fields.String,
-    'data_source_type': fields.String,
-    'name': fields.String,
-    'doc_type': fields.String,
-}
-
-segment_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'document_id': fields.String,
-    'content': fields.String,
-    'answer': fields.String,
-    'word_count': fields.Integer,
-    'tokens': fields.Integer,
-    'keywords': fields.List(fields.String),
-    'index_node_id': fields.String,
-    'index_node_hash': fields.String,
-    'hit_count': fields.Integer,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'status': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'indexing_at': TimestampField,
-    'completed_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField,
-    'document': fields.Nested(document_fields),
-}
-
-hit_testing_record_fields = {
-    'segment': fields.Nested(segment_fields),
-    'score': fields.Float,
-    'tsne_position': fields.Raw
-}
-
 
 class HitTestingApi(Resource):
 

+ 2 - 16
api/controllers/console/explore/conversation.py

@@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
+from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import TimestampField, uuid_value
 from services.conversation_service import ConversationService
 from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
 from services.web_conversation_service import WebConversationService
 
-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField
-}
-
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}
-
 
 class ConversationListApi(InstalledAppResource):
 
@@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
 
 class ConversationRenameApi(InstalledAppResource):
 
-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def post(self, installed_app, c_id):
         app_model = installed_app.app
         if app_model.mode != 'chat':

+ 1 - 22
api/controllers/console/explore/installed_app.py

@@ -11,32 +11,11 @@ from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
+from fields.installed_app_fields import installed_app_list_fields
 from libs.helper import TimestampField
 from models.model import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService
 
-app_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String
-}
-
-installed_app_fields = {
-    'id': fields.String,
-    'app': fields.Nested(app_fields),
-    'app_owner_tenant_id': fields.String,
-    'is_pinned': fields.Boolean,
-    'last_used_at': TimestampField,
-    'editable': fields.Boolean,
-    'uninstallable': fields.Boolean,
-}
-
-installed_app_list_fields = {
-    'installed_apps': fields.List(fields.Nested(installed_app_fields))
-}
-
 
 class InstalledAppsListApi(Resource):
     @login_required

+ 1 - 39
api/controllers/console/explore/message.py

@@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
 from controllers.console.explore.wraps import InstalledAppResource
 from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from fields.message_fields import message_infinite_scroll_pagination_fields
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
@@ -26,45 +27,6 @@ from services.message_service import MessageService
 
 
 class MessageListApi(InstalledAppResource):
-    feedback_fields = {
-        'rating': fields.String
-    }
-
-    retriever_resource_fields = {
-        'id': fields.String,
-        'message_id': fields.String,
-        'position': fields.Integer,
-        'dataset_id': fields.String,
-        'dataset_name': fields.String,
-        'document_id': fields.String,
-        'document_name': fields.String,
-        'data_source_type': fields.String,
-        'segment_id': fields.String,
-        'score': fields.Float,
-        'hit_count': fields.Integer,
-        'word_count': fields.Integer,
-        'segment_position': fields.Integer,
-        'index_node_hash': fields.String,
-        'content': fields.String,
-        'created_at': TimestampField
-    }
-
-    message_fields = {
-        'id': fields.String,
-        'conversation_id': fields.String,
-        'inputs': fields.Raw,
-        'query': fields.String,
-        'answer': fields.String,
-        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
-        'created_at': TimestampField
-    }
-
-    message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_fields))
-    }
 
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, installed_app):

+ 4 - 18
api/controllers/console/universal_chat/conversation.py

@@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
 
 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
+from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
+    conversation_with_model_config_fields
 from libs.helper import TimestampField, uuid_value
 from services.conversation_service import ConversationService
 from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
 from services.web_conversation_service import WebConversationService
 
-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField,
-    'model_config': fields.Raw,
-}
-
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}
-
 
 class UniversalChatConversationListApi(UniversalChatResource):
 
-    @marshal_with(conversation_infinite_scroll_pagination_fields)
+    @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
     def get(self, universal_app):
         app_model = universal_app
 
@@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
 
 class UniversalChatConversationRenameApi(UniversalChatResource):
 
-    @marshal_with(conversation_fields)
+    @marshal_with(conversation_with_model_config_fields)
     def post(self, universal_app, c_id):
         app_model = universal_app
         conversation_id = str(c_id)

+ 1 - 1
api/controllers/service_api/__init__.py

@@ -9,4 +9,4 @@ api = ExternalApi(bp)
 
 from .app import completion, app, conversation, message, audio
 
-from .dataset import document
+from .dataset import document, segment, dataset

+ 3 - 17
api/controllers/service_api/app/conversation.py

@@ -8,25 +8,11 @@ from controllers.service_api import api
 from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import AppApiResource
+from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import TimestampField, uuid_value
 import services
 from services.conversation_service import ConversationService
 
-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField
-}
-
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}
-
 
 class ConversationApi(AppApiResource):
 
@@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
             raise NotFound("Last Conversation Not Exists.")
 
 class ConversationDetailApi(AppApiResource):
-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def delete(self, app_model, end_user, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
@@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
 
 class ConversationRenameApi(AppApiResource):
 
-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

+ 84 - 0
api/controllers/service_api/dataset/dataset.py

@@ -0,0 +1,84 @@
+from flask import request
+from flask_restful import reqparse, marshal
+import services.dataset_service
+from controllers.service_api import api
+from controllers.service_api.dataset.error import DatasetNameDuplicateError
+from controllers.service_api.wraps import DatasetApiResource
+from core.login.login import current_user
+from core.model_providers.models.entity.model_params import ModelType
+from extensions.ext_database import db
+from fields.dataset_fields import dataset_detail_fields
+from models.account import Account, TenantAccountJoin
+from models.dataset import Dataset
+from services.dataset_service import DatasetService
+from services.provider_service import ProviderService
+
+
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError('Name must be between 1 to 40 characters.')
+    return name
+
+
class DatasetApi(DatasetApiResource):
    """Service API resource for listing and creating datasets."""

    def get(self, tenant_id):
        """Return a paginated list of the tenant's datasets.

        Each marshalled item is annotated with `embedding_available`: for
        high-quality datasets it reflects whether the dataset's embedding
        model is still among the tenant's valid models; economy datasets
        are always marked available.
        """
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        provider = request.args.get('provider', default="vendor")

        datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                      tenant_id, current_user)

        # Build the set of currently usable embedding models, keyed as
        # "<model_name>:<provider_name>" to match the per-dataset key below.
        provider_service = ProviderService()
        valid_model_list = provider_service.get_valid_model_list(
            current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
        model_names = [
            f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}"
            for valid_model in valid_model_list
        ]

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item['indexing_technique'] == 'high_quality':
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item['embedding_available'] = item_model in model_names
            else:
                item['embedding_available'] = True

        return {
            'data': data,
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }, 200

    def post(self, tenant_id):
        """Create an empty dataset for the tenant.

        Raises DatasetNameDuplicateError (HTTP error) when a dataset with
        the same name already exists.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='type is required. Name must be between 1 to 40 characters.',
                            type=_validate_name)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=('high_quality', 'economy'),
                            help='Invalid indexing technique.')
        args = parser.parse_args()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=tenant_id,
                name=args['name'],
                indexing_technique=args['indexing_technique'],
                account=current_user
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 200
+
+
# Register the dataset collection endpoint on the Service API blueprint.
api.add_resource(DatasetApi, '/datasets')


+ 322 - 68
api/controllers/service_api/dataset/document.py

@@ -1,114 +1,291 @@
 import datetime
+import json
 import uuid
 
-from flask import current_app
-from flask_restful import reqparse
+from flask import current_app, request
+from flask_restful import reqparse, marshal
+from sqlalchemy import desc
 from werkzeug.exceptions import NotFound
 
 import services.dataset_service
 from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
-    DatasetNotInitedError
+    NoFileUploadedError, TooManyFilesError
 from controllers.service_api.wraps import DatasetApiResource
+from core.login.login import current_user
 from core.model_providers.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_storage import storage
+from fields.document_fields import document_fields, document_status_fields
+from models.dataset import Dataset, Document, DocumentSegment
 from models.model import UploadFile
 from services.dataset_service import DocumentService
+from services.file_service import FileService
 
 
-class DocumentListApi(DatasetApiResource):
+class DocumentAddByTextApi(DatasetApiResource):
+    """Resource for creating a document from raw text."""
+
+    def post(self, tenant_id, dataset_id):
+        """Create document by text."""
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('original_document_id', type=str, required=False, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
+        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
+                            location='json')
+        args = parser.parse_args()
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        # Scope the lookup by tenant so a token can never touch another tenant's dataset.
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+
+        if not dataset:
+            raise ValueError('Dataset does not exist.')
+
+        if not dataset.indexing_technique and not args['indexing_technique']:
+            raise ValueError('indexing_technique is required.')
+
+        # Persist the raw text as an upload file so it flows through the normal file pipeline.
+        upload_file = FileService.upload_text(args.get('text'), args.get('name'))
+        data_source = {
+            'type': 'upload_file',
+            'info_list': {
+                'data_source_type': 'upload_file',
+                'file_info_list': {
+                    'file_ids': [upload_file.id]
+                }
+            }
+        }
+        args['data_source'] = data_source
+        # validate args
+        DocumentService.document_create_args_validate(args)
+
+        try:
+            documents, batch = DocumentService.save_document_with_dataset_id(
+                dataset=dataset,
+                document_data=args,
+                account=current_user,
+                # BUG FIX: reqparse.parse_args() always includes 'process_rule' in args
+                # (value None when the client omitted it), so `'process_rule' not in args`
+                # was always False and the dataset's latest rule was never reused.
+                # Test the value instead of key membership.
+                dataset_process_rule=dataset.latest_process_rule if not args['process_rule'] else None,
+                created_from='api'
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        document = documents[0]
+
+        documents_and_batch_fields = {
+            'document': marshal(document, document_fields),
+            'batch': batch
+        }
+        return documents_and_batch_fields, 200
+
+
+class DocumentUpdateByTextApi(DatasetApiResource):
+    """Resource for updating an existing document from raw text."""
+
+    def post(self, tenant_id, dataset_id, document_id):
+        """Update document by text."""
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=False, nullable=True, location='json')
+        parser.add_argument('text', type=str, required=False, nullable=True, location='json')
+        parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                            location='json')
+        args = parser.parse_args()
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+
+        if not dataset:
+            raise ValueError('Dataset does not exist.')
+
+        # Text is optional on update; when supplied, stage it as a new upload file.
+        if args['text']:
+            upload_file = FileService.upload_text(args.get('text'), args.get('name'))
+            data_source = {
+                'type': 'upload_file',
+                'info_list': {
+                    'data_source_type': 'upload_file',
+                    'file_info_list': {
+                        'file_ids': [upload_file.id]
+                    }
+                }
+            }
+            args['data_source'] = data_source
+        # validate args
+        args['original_document_id'] = str(document_id)
+        DocumentService.document_create_args_validate(args)
+
+        try:
+            documents, batch = DocumentService.save_document_with_dataset_id(
+                dataset=dataset,
+                document_data=args,
+                account=current_user,
+                # BUG FIX: reqparse always puts 'process_rule' in args, so the old
+                # `'process_rule' not in args` test was always False and the dataset's
+                # latest rule was never reused. Test the value instead.
+                dataset_process_rule=dataset.latest_process_rule if not args['process_rule'] else None,
+                created_from='api'
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        document = documents[0]
+
+        documents_and_batch_fields = {
+            'document': marshal(document, document_fields),
+            'batch': batch
+        }
+        return documents_and_batch_fields, 200
+
+
+class DocumentAddByFileApi(DatasetApiResource):
+    """Resource for creating a document from an uploaded file."""
+
+    def post(self, tenant_id, dataset_id):
+        """Create document by upload file."""
+        args = {}
+        if 'data' in request.form:
+            args = json.loads(request.form['data'])
+        if 'doc_form' not in args:
+            args['doc_form'] = 'text_model'
+        if 'doc_language' not in args:
+            args['doc_language'] = 'English'
+        # get dataset info
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+
+        if not dataset:
+            raise ValueError('Dataset does not exist.')
+        # BUG FIX: args comes from client JSON, so 'indexing_technique' may be absent
+        # entirely; args['indexing_technique'] raised KeyError. Use .get() instead.
+        if not dataset.indexing_technique and not args.get('indexing_technique'):
+            raise ValueError('indexing_technique is required.')
+
+        # BUG FIX: the file was read via request.files['file'] *before* the presence
+        # check, so a missing file produced a 400 BadRequestKeyError instead of the
+        # intended NoFileUploadedError. Validate first, then read.
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        file = request.files['file']
+        upload_file = FileService.upload_file(file)
+        data_source = {
+            'type': 'upload_file',
+            'info_list': {
+                # include data_source_type for consistency with the text endpoints
+                'data_source_type': 'upload_file',
+                'file_info_list': {
+                    'file_ids': [upload_file.id]
+                }
+            }
+        }
+        args['data_source'] = data_source
+        # validate args
+        DocumentService.document_create_args_validate(args)
+
+        try:
+            documents, batch = DocumentService.save_document_with_dataset_id(
+                dataset=dataset,
+                document_data=args,
+                account=dataset.created_by_account,
+                # args is a plain dict here, so key membership is a valid "omitted" test.
+                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
+                created_from='api'
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        document = documents[0]
+        documents_and_batch_fields = {
+            'document': marshal(document, document_fields),
+            'batch': batch
+        }
+        return documents_and_batch_fields, 200
 
-            document.doc_metadata = {}
 
-            for key, value_type in metadata_schema.items():
-                value = doc_metadata.get(key)
-                if value is not None and isinstance(value, value_type):
-                    document.doc_metadata[key] = value
+class DocumentUpdateByFileApi(DatasetApiResource):
+    """Resource for updating an existing document from an uploaded file."""
+
+    def post(self, tenant_id, dataset_id, document_id):
+        """Update document by upload file."""
+        args = {}
+        if 'data' in request.form:
+            args = json.loads(request.form['data'])
+        if 'doc_form' not in args:
+            args['doc_form'] = 'text_model'
+        if 'doc_language' not in args:
+            args['doc_language'] = 'English'
+
+        # get dataset info
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+
+        if not dataset:
+            raise ValueError('Dataset does not exist.')
+        # The file is optional on update; when present it replaces the document source.
+        if 'file' in request.files:
+            # FIX: reject multiple files before touching the upload, and drop the
+            # stray blank lines left between the read and the count check.
+            if len(request.files) > 1:
+                raise TooManyFilesError()
+            file = request.files['file']
+
+            upload_file = FileService.upload_file(file)
+            data_source = {
+                'type': 'upload_file',
+                'info_list': {
+                    # include data_source_type for consistency with the text endpoints
+                    'data_source_type': 'upload_file',
+                    'file_info_list': {
+                        'file_ids': [upload_file.id]
+                    }
+                }
+            }
+            args['data_source'] = data_source
+        # validate args
+        args['original_document_id'] = str(document_id)
+        DocumentService.document_create_args_validate(args)
+
+        try:
+            documents, batch = DocumentService.save_document_with_dataset_id(
+                dataset=dataset,
+                document_data=args,
+                account=dataset.created_by_account,
+                # args is a plain dict here, so key membership is a valid "omitted" test.
+                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
+                created_from='api'
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        document = documents[0]
+        documents_and_batch_fields = {
+            'document': marshal(document, document_fields),
+            'batch': batch
+        }
+        return documents_and_batch_fields, 200
 
 
-class DocumentApi(DatasetApiResource):
-    def delete(self, dataset, document_id):
+class DocumentDeleteApi(DatasetApiResource):
+    def delete(self, tenant_id, dataset_id, document_id):
         """Delete document."""
         document_id = str(document_id)
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+
+        # get dataset info
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+
+        if not dataset:
+            raise ValueError('Dataset is not exist.')
 
         document = DocumentService.get_document(dataset.id, document_id)
 
@@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource):
         except services.errors.document.DocumentIndexingError:
             raise DocumentIndexingError('Cannot delete document during indexing.')
 
-        return {'result': 'success'}, 204
+        return {'result': 'success'}, 200
+
+
+class DocumentListApi(DatasetApiResource):
+    def get(self, tenant_id, dataset_id):
+        """Paginated listing of a dataset's documents, newest first.
+
+        Query params: page (default 1), limit (default 20, capped at 100),
+        keyword (optional substring filter on document name).
+        """
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        page = request.args.get('page', default=1, type=int)
+        limit = request.args.get('limit', default=20, type=int)
+        search = request.args.get('keyword', default=None, type=str)
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+        if not dataset:
+            raise NotFound('Dataset not found.')
+
+        query = Document.query.filter_by(
+            dataset_id=str(dataset_id), tenant_id=tenant_id)
+
+        if search:
+            # wrap the raw keyword in % wildcards for a substring LIKE match
+            search = f'%{search}%'
+            query = query.filter(Document.name.like(search))
+
+        query = query.order_by(desc(Document.created_at))
+
+        paginated_documents = query.paginate(
+            page=page, per_page=limit, max_per_page=100, error_out=False)
+        documents = paginated_documents.items
+
+        response = {
+            'data': marshal(documents, document_fields),
+            # heuristic: a full page implies more results may follow
+            'has_more': len(documents) == limit,
+            'limit': limit,
+            'total': paginated_documents.total,
+            'page': page
+        }
+
+        return response
+
+
+class DocumentIndexingStatusApi(DatasetApiResource):
+    def get(self, tenant_id, dataset_id, batch):
+        """Return per-document indexing progress for one upload batch."""
+        dataset_id = str(dataset_id)
+        batch = str(batch)
+        tenant_id = str(tenant_id)
+        # get dataset
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # get documents
+        documents = DocumentService.get_batch_documents(dataset_id, batch)
+        if not documents:
+            raise NotFound('Documents not found.')
+        documents_status = []
+        for document in documents:
+            # Progress = completed segments vs. total segments, excluding segments
+            # currently marked for re-segmentation.
+            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
+                                                              DocumentSegment.document_id == str(document.id),
+                                                              DocumentSegment.status != 're_segment').count()
+            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
+                                                          DocumentSegment.status != 're_segment').count()
+            document.completed_segments = completed_segments
+            document.total_segments = total_segments
+            if document.is_paused:
+                document.indexing_status = 'paused'
+            documents_status.append(marshal(document, document_status_fields))
+        data = {
+            'data': documents_status
+        }
+        return data
 
 
-api.add_resource(DocumentListApi, '/documents')
-api.add_resource(DocumentApi, '/documents/<uuid:document_id>')
+api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
+api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
+api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
+api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
+api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
+api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
+api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')

+ 61 - 8
api/controllers/service_api/dataset/error.py

@@ -1,20 +1,73 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
 
 
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
+
+
+class HighQualityDatasetOnlyError(BaseHTTPException):
+    error_code = 'high_quality_dataset_only'
+    description = "Current operation only supports 'high-quality' datasets."
+    code = 400
+
+
+class DatasetNotInitializedError(BaseHTTPException):
+    error_code = 'dataset_not_initialized'
+    description = "The dataset is still being initialized or indexing. Please wait a moment."
+    code = 400
+
+
 class ArchivedDocumentImmutableError(BaseHTTPException):
     error_code = 'archived_document_immutable'
-    description = "Cannot operate when document was archived."
+    description = "The archived document is not editable."
     code = 403
 
 
+class DatasetNameDuplicateError(BaseHTTPException):
+    error_code = 'dataset_name_duplicate'
+    description = "The dataset name already exists. Please modify your dataset name."
+    code = 409
+
+
+class InvalidActionError(BaseHTTPException):
+    error_code = 'invalid_action'
+    description = "Invalid action."
+    code = 400
+
+
+class DocumentAlreadyFinishedError(BaseHTTPException):
+    error_code = 'document_already_finished'
+    description = "The document has been processed. Please refresh the page or go to the document details."
+    code = 400
+
+
 class DocumentIndexingError(BaseHTTPException):
     error_code = 'document_indexing'
-    description = "Cannot operate document during indexing."
-    code = 403
+    description = "The document is being processed and cannot be edited."
+    code = 400
 
 
-class DatasetNotInitedError(BaseHTTPException):
-    error_code = 'dataset_not_inited'
-    description = "The dataset is still being initialized or indexing. Please wait a moment."
-    code = 403
+class InvalidMetadataError(BaseHTTPException):
+    error_code = 'invalid_metadata'
+    description = "The metadata content is incorrect. Please check and verify."
+    code = 400

+ 59 - 0
api/controllers/service_api/dataset/segment.py

@@ -0,0 +1,59 @@
+from flask_login import current_user
+from flask_restful import reqparse, marshal
+from werkzeug.exceptions import NotFound
+
+from controllers.service_api import api
+from controllers.service_api.app.error import ProviderNotInitializeError
+from controllers.service_api.wraps import DatasetApiResource
+from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
+from extensions.ext_database import db
+from fields.segment_fields import segment_fields
+from models.dataset import Dataset
+from services.dataset_service import DocumentService, SegmentService
+
+
+class SegmentApi(DatasetApiResource):
+    """Resource for segments."""
+    def post(self, tenant_id, dataset_id, document_id):
+        """Create one or more segments under a document."""
+        # check dataset
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+        # BUG FIX: dataset can be None (wrong id or another tenant's dataset); the
+        # code previously fell through and crashed with AttributeError on dataset.id.
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check embedding model setting
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
+        args = parser.parse_args()
+        # BUG FIX: 'segments' is optional/nullable, so it may be None; guard before
+        # iterating instead of raising TypeError.
+        segments_payload = args['segments'] or []
+        for args_item in segments_payload:
+            SegmentService.segment_create_args_validate(args_item, document)
+        segments = SegmentService.multi_create_segment(segments_payload, document, dataset)
+        return {
+            'data': marshal(segments, segment_fields),
+            'doc_form': document.doc_form
+        }, 200
+
+
+api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')

+ 22 - 7
api/controllers/service_api/wraps.py

@@ -2,11 +2,14 @@
 from datetime import datetime
 from functools import wraps
 
-from flask import request
+from flask import request, current_app
+from flask_login import user_logged_in
 from flask_restful import Resource
 from werkzeug.exceptions import NotFound, Unauthorized
 
+from core.login.login import _get_user
 from extensions.ext_database import db
+from models.account import Tenant, TenantAccountJoin, Account
 from models.dataset import Dataset
 from models.model import ApiToken, App
 
@@ -43,12 +46,24 @@ def validate_dataset_token(view=None):
         @wraps(view)
         def decorated(*args, **kwargs):
             api_token = validate_and_get_api_token('dataset')
-
-            dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
-            if not dataset:
-                raise NotFound()
-
-            return view(dataset, *args, **kwargs)
+            tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
+                .filter(Tenant.id == api_token.tenant_id) \
+                .filter(TenantAccountJoin.tenant_id == Tenant.id) \
+                .filter(TenantAccountJoin.role == 'owner') \
+                .one_or_none()
+            if tenant_account_join:
+                tenant, ta = tenant_account_join
+                account = Account.query.filter_by(id=ta.account_id).first()
+                # Login admin
+                if account:
+                    account.current_tenant = tenant
+                    current_app.login_manager._update_request_context_with_user(account)
+                    user_logged_in.send(current_app._get_current_object(), user=_get_user())
+                else:
+                    raise Unauthorized("Tenant owner account is not exist.")
+            else:
+                raise Unauthorized("Tenant is not exist.")
+            return view(api_token.tenant_id, *args, **kwargs)
         return decorated
 
     if view:

+ 2 - 16
api/controllers/web/conversation.py

@@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
 from controllers.web import api
 from controllers.web.error import NotChatAppError
 from controllers.web.wraps import WebApiResource
+from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import TimestampField, uuid_value
 from services.conversation_service import ConversationService
 from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
 from services.web_conversation_service import WebConversationService
 
-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField
-}
-
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}
-
 
 class ConversationListApi(WebApiResource):
 
@@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
 
 class ConversationRenameApi(WebApiResource):
 
-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

+ 1 - 0
api/core/data_loader/loader/notion.py

@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
 BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
 DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
 SEARCH_URL = "https://api.notion.com/v1/search"
+
 RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
 RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
 HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']

+ 17 - 0
api/core/index/keyword_table_index/keyword_table_index.py

@@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
         keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)
 
+    def multi_create_segment_keywords(self, pre_segment_data_list: list):
+        """Register keywords for a batch of new segments in the dataset keyword table.
+
+        Each item carries a 'segment' and optional 'keywords'. Caller-supplied
+        keywords are used verbatim; otherwise keywords are extracted from the
+        segment content with jieba. The table is saved once at the end of the batch.
+        """
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keyword_table = self._get_dataset_keyword_table()
+        for pre_segment_data in pre_segment_data_list:
+            segment = pre_segment_data['segment']
+            if pre_segment_data['keywords']:
+                segment.keywords = pre_segment_data['keywords']
+                keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
+                                                                pre_segment_data['keywords'])
+            else:
+                keywords = keyword_table_handler.extract_keywords(segment.content,
+                                                                  self._config.max_keywords_per_chunk)
+                segment.keywords = list(keywords)
+                keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
+        self._save_dataset_keyword_table(keyword_table)
+
     def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
         keyword_table = self._get_dataset_keyword_table()
         keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)
 
+
 class KeywordTableRetriever(BaseRetriever, BaseModel):
     index: KeywordTableIndex
     search_kwargs: dict = Field(default_factory=dict)

+ 0 - 0
api/fields/__init__.py


+ 138 - 0
api/fields/app_fields.py

@@ -0,0 +1,138 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
+app_detail_kernel_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'mode': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+}
+
+related_app_list = {
+    'data': fields.List(fields.Nested(app_detail_kernel_fields)),
+    'total': fields.Integer,
+}
+
+model_config_fields = {
+    'opening_statement': fields.String,
+    'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
+    'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
+    'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
+    'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
+    'more_like_this': fields.Raw(attribute='more_like_this_dict'),
+    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
+    'model': fields.Raw(attribute='model_dict'),
+    'user_input_form': fields.Raw(attribute='user_input_form_list'),
+    'dataset_query_variable': fields.String,
+    'pre_prompt': fields.String,
+    'agent_mode': fields.Raw(attribute='agent_mode_dict'),
+}
+
+app_detail_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'mode': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+    'enable_site': fields.Boolean,
+    'enable_api': fields.Boolean,
+    'api_rpm': fields.Integer,
+    'api_rph': fields.Integer,
+    'is_demo': fields.Boolean,
+    'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
+    'created_at': TimestampField
+}
+
+prompt_config_fields = {
+    'prompt_template': fields.String,
+}
+
+model_config_partial_fields = {
+    'model': fields.Raw(attribute='model_dict'),
+    'pre_prompt': fields.String,
+}
+
+app_partial_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'mode': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+    'enable_site': fields.Boolean,
+    'enable_api': fields.Boolean,
+    'is_demo': fields.Boolean,
+    'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
+    'created_at': TimestampField
+}
+
+app_pagination_fields = {
+    'page': fields.Integer,
+    'limit': fields.Integer(attribute='per_page'),
+    'total': fields.Integer,
+    'has_more': fields.Boolean(attribute='has_next'),
+    'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
+}
+
+template_fields = {
+    'name': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+    'description': fields.String,
+    'mode': fields.String,
+    'model_config': fields.Nested(model_config_fields),
+}
+
+template_list_fields = {
+    'data': fields.List(fields.Nested(template_fields)),
+}
+
+site_fields = {
+    'access_token': fields.String(attribute='code'),
+    'code': fields.String,
+    'title': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+    'description': fields.String,
+    'default_language': fields.String,
+    'customize_domain': fields.String,
+    'copyright': fields.String,
+    'privacy_policy': fields.String,
+    'customize_token_strategy': fields.String,
+    'prompt_public': fields.Boolean,
+    'app_base_url': fields.String,
+}
+
+app_detail_fields_with_site = {
+    'id': fields.String,
+    'name': fields.String,
+    'mode': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+    'enable_site': fields.Boolean,
+    'enable_api': fields.Boolean,
+    'api_rpm': fields.Integer,
+    'api_rph': fields.Integer,
+    'is_demo': fields.Boolean,
+    'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
+    'site': fields.Nested(site_fields),
+    'api_base_url': fields.String,
+    'created_at': TimestampField
+}
+
+app_site_fields = {
+    'app_id': fields.String,
+    'access_token': fields.String(attribute='code'),
+    'code': fields.String,
+    'title': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String,
+    'description': fields.String,
+    'default_language': fields.String,
+    'customize_domain': fields.String,
+    'copyright': fields.String,
+    'privacy_policy': fields.String,
+    'customize_token_strategy': fields.String,
+    'prompt_public': fields.Boolean
+}

+ 182 - 0
api/fields/conversation_fields.py

@@ -0,0 +1,182 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
+
class MessageTextField(fields.Raw):
    """Render a message value as the text of its first part.

    The stored message is a list of parts ``[{'text': ...}, ...]``; an empty
    or missing value serializes to the empty string.
    """

    def format(self, value):
        if not value:
            return ''
        first_part = value[0]
        return first_part['text']
+
+
# Minimal author/end-user identity embedded in feedback and annotations.
account_fields = {
    'id': fields.String,
    'name': fields.String,
    'email': fields.String
}

feedback_fields = {
    'rating': fields.String,
    'content': fields.String,
    'from_source': fields.String,
    'from_end_user_id': fields.String,
    'from_account': fields.Nested(account_fields, allow_null=True),
}

annotation_fields = {
    'content': fields.String,
    'account': fields.Nested(account_fields, allow_null=True),
    'created_at': TimestampField
}

# Full message payload for the console message-detail views.
message_detail_fields = {
    'id': fields.String,
    'conversation_id': fields.String,
    'inputs': fields.Raw,
    'query': fields.String,
    'message': fields.Raw,
    'message_tokens': fields.Integer,
    'answer': fields.String,
    'answer_tokens': fields.Integer,
    'provider_response_latency': fields.Float,
    'from_source': fields.String,
    'from_end_user_id': fields.String,
    'from_account_id': fields.String,
    'feedbacks': fields.List(fields.Nested(feedback_fields)),
    'annotation': fields.Nested(annotation_fields, allow_null=True),
    'created_at': TimestampField
}

# Aggregated like/dislike counters.
feedback_stat_fields = {
    'like': fields.Integer,
    'dislike': fields.Integer
}

model_config_fields = {
    'opening_statement': fields.String,
    'suggested_questions': fields.Raw,
    'model': fields.Raw,
    'user_input_form': fields.Raw,
    'pre_prompt': fields.String,
    'agent_mode': fields.Raw,
}

simple_configs_fields = {
    'prompt_template': fields.String,
}

# Trimmed model config; 'model' reads the config's ``model_dict`` attribute.
simple_model_config_fields = {
    'model': fields.Raw(attribute='model_dict'),
    'pre_prompt': fields.String,
}

simple_message_detail_fields = {
    'inputs': fields.Raw,
    'query': fields.String,
    'message': MessageTextField,
    'answer': fields.String,
}

# Conversation row for list views; 'message' exposes the conversation's
# ``first_message`` attribute.
conversation_fields = {
    'id': fields.String,
    'status': fields.String,
    'from_source': fields.String,
    'from_end_user_id': fields.String,
    'from_end_user_session_id': fields.String(),
    'from_account_id': fields.String,
    'read_at': TimestampField,
    'created_at': TimestampField,
    'annotation': fields.Nested(annotation_fields, allow_null=True),
    'model_config': fields.Nested(simple_model_config_fields),
    'user_feedback_stats': fields.Nested(feedback_stat_fields),
    'admin_feedback_stats': fields.Nested(feedback_stat_fields),
    'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}

# Wraps a pagination object: 'limit'/'has_more'/'data' map to the
# paginator's ``per_page``/``has_next``/``items`` attributes.
conversation_pagination_fields = {
    'page': fields.Integer,
    'limit': fields.Integer(attribute='per_page'),
    'total': fields.Integer,
    'has_more': fields.Boolean(attribute='has_next'),
    'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
+
# Conversation header plus its full first message.
conversation_message_detail_fields = {
    'id': fields.String,
    'status': fields.String,
    'from_source': fields.String,
    'from_end_user_id': fields.String,
    'from_account_id': fields.String,
    'created_at': TimestampField,
    'model_config': fields.Nested(model_config_fields),
    'message': fields.Nested(message_detail_fields, attribute='first_message'),
}

# NOTE: a second, byte-identical assignment of ``simple_model_config_fields``
# used to follow here; it was removed as a pure duplicate of the definition
# earlier in this module.
+
# Conversation row with a one-line summary; 'summary' reads the model's
# ``summary_or_query`` attribute.
conversation_with_summary_fields = {
    'id': fields.String,
    'status': fields.String,
    'from_source': fields.String,
    'from_end_user_id': fields.String,
    'from_end_user_session_id': fields.String,
    'from_account_id': fields.String,
    'summary': fields.String(attribute='summary_or_query'),
    'read_at': TimestampField,
    'created_at': TimestampField,
    'annotated': fields.Boolean,
    'model_config': fields.Nested(simple_model_config_fields),
    'message_count': fields.Integer,
    'user_feedback_stats': fields.Nested(feedback_stat_fields),
    'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}

# Pagination wrapper; same per_page/has_next/items attribute mapping as
# conversation_pagination_fields.
conversation_with_summary_pagination_fields = {
    'page': fields.Integer,
    'limit': fields.Integer(attribute='per_page'),
    'total': fields.Integer,
    'has_more': fields.Boolean(attribute='has_next'),
    'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
}

conversation_detail_fields = {
    'id': fields.String,
    'status': fields.String,
    'from_source': fields.String,
    'from_end_user_id': fields.String,
    'from_account_id': fields.String,
    'created_at': TimestampField,
    'annotated': fields.Boolean,
    'model_config': fields.Nested(model_config_fields),
    'message_count': fields.Integer,
    'user_feedback_stats': fields.Nested(feedback_stat_fields),
    'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}

# Slim conversation shape used by infinite-scroll endpoints.
simple_conversation_fields = {
    'id': fields.String,
    'name': fields.String,
    'inputs': fields.Raw,
    'status': fields.String,
    'introduction': fields.String,
    'created_at': TimestampField
}

conversation_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(simple_conversation_fields))
}

# Simple shape plus the raw model_config blob.
conversation_with_model_config_fields = {
    **simple_conversation_fields,
    'model_config': fields.Raw,
}

conversation_with_model_config_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}

+ 65 - 0
api/fields/data_source_fields.py

@@ -0,0 +1,65 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
# -- Bound-state variant ------------------------------------------------------
# First definitions of the integrate_* schemas; pages carry an ``is_bound``
# flag indicating whether the Notion page is already attached to a dataset.
integrate_icon_fields = {
    'type': fields.String,
    'url': fields.String,
    'emoji': fields.String
}

integrate_page_fields = {
    'page_name': fields.String,
    'page_id': fields.String,
    'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
    'is_bound': fields.Boolean,
    'parent_id': fields.String,
    'type': fields.String
}

integrate_workspace_fields = {
    'workspace_name': fields.String,
    'workspace_id': fields.String,
    'workspace_icon': fields.String,
    'pages': fields.List(fields.Nested(integrate_page_fields))
}

# Captures the dict objects defined above.  The re-assignments below do not
# affect this schema: fields.Nested/fields.List hold a reference to the
# original dict objects, not to the module-level names.
integrate_notion_info_list_fields = {
    'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}

# -- Listing variant ----------------------------------------------------------
# NOTE(review): the three names below deliberately shadow the definitions
# above (pages drop 'is_bound', workspaces gain 'total').  Distinct names
# would be clearer, but renaming would change this module's public surface.
integrate_icon_fields = {
    'type': fields.String,
    'url': fields.String,
    'emoji': fields.String
}

integrate_page_fields = {
    'page_name': fields.String,
    'page_id': fields.String,
    'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
    'parent_id': fields.String,
    'type': fields.String
}

integrate_workspace_fields = {
    'workspace_name': fields.String,
    'workspace_id': fields.String,
    'workspace_icon': fields.String,
    'pages': fields.List(fields.Nested(integrate_page_fields)),
    'total': fields.Integer
}

# A data-source integration record; 'source_info' uses the listing-variant
# workspace schema defined immediately above.
integrate_fields = {
    'id': fields.String,
    'provider': fields.String,
    'created_at': TimestampField,
    'is_bound': fields.Boolean,
    'disabled': fields.Boolean,
    'link': fields.String,
    'source_info': fields.Nested(integrate_workspace_fields)
}

integrate_list_fields = {
    'data': fields.List(fields.Nested(integrate_fields)),
}

+ 43 - 0
api/fields/dataset_fields.py

@@ -0,0 +1,43 @@
+from flask_restful import fields
+from libs.helper import TimestampField
+
# Summary shape used when a dataset is embedded in other payloads.
dataset_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
}

# Full dataset detail, including usage counters and the embedding-model
# configuration/availability flags.
dataset_detail_fields = {
    'id': fields.String,
    'name': fields.String,
    'description': fields.String,
    'provider': fields.String,
    'permission': fields.String,
    'data_source_type': fields.String,
    'indexing_technique': fields.String,
    'app_count': fields.Integer,
    'document_count': fields.Integer,
    'word_count': fields.Integer,
    'created_by': fields.String,
    'created_at': TimestampField,
    'updated_by': fields.String,
    'updated_at': TimestampField,
    'embedding_model': fields.String,
    'embedding_model_provider': fields.String,
    'embedding_available': fields.Boolean
}

# A single logged query against a dataset.
dataset_query_detail_fields = {
    "id": fields.String,
    "content": fields.String,
    "source": fields.String,
    "source_app_id": fields.String,
    "created_by_role": fields.String,
    "created_by": fields.String,
    "created_at": TimestampField
}

+ 76 - 0
api/fields/document_fields.py

@@ -0,0 +1,76 @@
+from flask_restful import fields
+
+from fields.dataset_fields import dataset_fields
+from libs.helper import TimestampField
+
# Serialized dataset document.
document_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'data_source_type': fields.String,
    # exposed via the model's parsed ``data_source_info_dict`` property
    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
    'dataset_process_rule_id': fields.String,
    'name': fields.String,
    'created_from': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'tokens': fields.Integer,
    'indexing_status': fields.String,
    'error': fields.String,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'archived': fields.Boolean,
    'display_status': fields.String,
    'word_count': fields.Integer,
    'hit_count': fields.Integer,
    'doc_form': fields.String,
}

# NOTE(review): near-duplicate of document_fields -- it adds the two segment
# counters but omits 'doc_form', so it cannot be rewritten as
# ``{**document_fields, ...}`` without changing the response shape.
document_with_segments_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'data_source_type': fields.String,
    'data_source_info': fields.Raw(attribute='data_source_info_dict'),
    'dataset_process_rule_id': fields.String,
    'name': fields.String,
    'created_from': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'tokens': fields.Integer,
    'indexing_status': fields.String,
    'error': fields.String,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'archived': fields.Boolean,
    'display_status': fields.String,
    'word_count': fields.Integer,
    'hit_count': fields.Integer,
    'completed_segments': fields.Integer,
    'total_segments': fields.Integer
}

# Response for document-create endpoints: the dataset, the created documents
# and the batch identifier used to poll indexing status.
dataset_and_document_fields = {
    'dataset': fields.Nested(dataset_fields),
    'documents': fields.List(fields.Nested(document_fields)),
    'batch': fields.String
}

# Per-document indexing progress timestamps and segment counters.
document_status_fields = {
    'id': fields.String,
    'indexing_status': fields.String,
    'processing_started_at': TimestampField,
    'parsing_completed_at': TimestampField,
    'cleaning_completed_at': TimestampField,
    'splitting_completed_at': TimestampField,
    'completed_at': TimestampField,
    'paused_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField,
    'completed_segments': fields.Integer,
    'total_segments': fields.Integer,
}

document_status_fields_list = {
    'data': fields.List(fields.Nested(document_status_fields))
}

+ 18 - 0
api/fields/file_fields.py

@@ -0,0 +1,18 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
# Upload constraints advertised to the client.
upload_config_fields = {
    'file_size_limit': fields.Integer,
    'batch_count_limit': fields.Integer
}

# Serialized UploadFile record.
file_fields = {
    'id': fields.String,
    'name': fields.String,
    'size': fields.Integer,
    'extension': fields.String,
    'mime_type': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
}

+ 41 - 0
api/fields/hit_testing_fields.py

@@ -0,0 +1,41 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
# Slim document shape local to hit-testing results (distinct from the fuller
# document_fields in fields.document_fields).
document_fields = {
    'id': fields.String,
    'data_source_type': fields.String,
    'name': fields.String,
    'doc_type': fields.String,
}

# A retrieved segment, including its indexing lifecycle timestamps and the
# owning document.
segment_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'answer': fields.String,
    'word_count': fields.Integer,
    'tokens': fields.Integer,
    'keywords': fields.List(fields.String),
    'index_node_id': fields.String,
    'index_node_hash': fields.String,
    'hit_count': fields.Integer,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'status': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'indexing_at': TimestampField,
    'completed_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField,
    'document': fields.Nested(document_fields),
}

# One hit-testing result: the matched segment, its similarity score and the
# t-SNE coordinates used by the console visualization.
hit_testing_record_fields = {
    'segment': fields.Nested(segment_fields),
    'score': fields.Float,
    'tsne_position': fields.Raw
}

+ 25 - 0
api/fields/installed_app_fields.py

@@ -0,0 +1,25 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
# Minimal app shape embedded inside an installed-app record.
app_fields = {
    'id': fields.String,
    'name': fields.String,
    'mode': fields.String,
    'icon': fields.String,
    'icon_background': fields.String
}

installed_app_fields = {
    'id': fields.String,
    'app': fields.Nested(app_fields),
    'app_owner_tenant_id': fields.String,
    'is_pinned': fields.Boolean,
    'last_used_at': TimestampField,
    'editable': fields.Boolean,
    'uninstallable': fields.Boolean,
}

installed_app_list_fields = {
    'installed_apps': fields.List(fields.Nested(installed_app_fields))
}

+ 43 - 0
api/fields/message_fields.py

@@ -0,0 +1,43 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
feedback_fields = {
    'rating': fields.String
}

# A retrieval citation attached to a message.
retriever_resource_fields = {
    'id': fields.String,
    'message_id': fields.String,
    'position': fields.Integer,
    'dataset_id': fields.String,
    'dataset_name': fields.String,
    'document_id': fields.String,
    'document_name': fields.String,
    'data_source_type': fields.String,
    'segment_id': fields.String,
    'score': fields.Float,
    'hit_count': fields.Integer,
    'word_count': fields.Integer,
    'segment_position': fields.Integer,
    'index_node_hash': fields.String,
    'content': fields.String,
    'created_at': TimestampField
}

# Message shape for end-user facing endpoints; 'feedback' reads the model's
# ``user_feedback`` attribute.
message_fields = {
    'id': fields.String,
    'conversation_id': fields.String,
    'inputs': fields.Raw,
    'query': fields.String,
    'answer': fields.String,
    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
    'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
    'created_at': TimestampField
}

message_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(message_fields))
}

+ 32 - 0
api/fields/segment_fields.py

@@ -0,0 +1,32 @@
+from flask_restful import fields
+from libs.helper import TimestampField
+
# Serialized document segment, including its indexing lifecycle timestamps.
segment_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'answer': fields.String,
    'word_count': fields.Integer,
    'tokens': fields.Integer,
    'keywords': fields.List(fields.String),
    'index_node_id': fields.String,
    'index_node_hash': fields.String,
    'hit_count': fields.Integer,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'status': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'indexing_at': TimestampField,
    'completed_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField
}

segment_list_response = {
    'data': fields.List(fields.Nested(segment_fields)),
    'has_more': fields.Boolean,
    'limit': fields.Integer
}

+ 36 - 0
api/migrations/versions/2e9819ca5b28_add_tenant_id_in_api_token.py

@@ -0,0 +1,36 @@
"""add_tenant_id_in_api_token

Revision ID: 2e9819ca5b28
Revises: ab23c11305d4
Create Date: 2023-09-22 15:41:01.243183

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
# NOTE: the auto-generated docstring said "Revises: 6e2cfb077b04", which
# contradicted this value; ``down_revision`` is what Alembic actually uses,
# so the docstring above has been corrected to match it.
down_revision = 'ab23c11305d4'
branch_labels = None
depends_on = None
+
+
def upgrade():
    """Re-scope API tokens from a single dataset to a tenant."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('api_tokens', schema=None) as batch_op:
        # Nullable so existing app-scoped tokens remain valid.
        batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
        batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
        # Destructive: dataset bindings are dropped and cannot be restored by
        # downgrade(), which only re-adds an empty column.
        batch_op.drop_column('dataset_id')

    # ### end Alembic commands ###
+
+
def downgrade():
    """Revert the schema change; previous dataset_id values are NOT restored."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('api_tokens', schema=None) as batch_op:
        batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
        batch_op.drop_index('api_token_tenant_idx')
        batch_op.drop_column('tenant_id')

    # ### end Alembic commands ###

+ 3 - 2
api/models/model.py

@@ -629,12 +629,13 @@ class ApiToken(db.Model):
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='api_token_pkey'),
         db.Index('api_token_app_id_type_idx', 'app_id', 'type'),
-        db.Index('api_token_token_idx', 'token', 'type')
+        db.Index('api_token_token_idx', 'token', 'type'),
+        db.Index('api_token_tenant_idx', 'tenant_id', 'type')
     )
 
     id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
     app_id = db.Column(UUID, nullable=True)
-    dataset_id = db.Column(UUID, nullable=True)
+    tenant_id = db.Column(UUID, nullable=True)
     type = db.Column(db.String(16), nullable=False)
     token = db.Column(db.String(255), nullable=False)
     last_used_at = db.Column(db.DateTime, nullable=True)

+ 66 - 2
api/services/dataset_service.py

@@ -96,7 +96,7 @@ class DatasetService:
         embedding_model = None
         if indexing_technique == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
-                tenant_id=current_user.current_tenant_id
+                tenant_id=tenant_id
             )
         dataset = Dataset(name=name, indexing_technique=indexing_technique)
         # dataset = Dataset(name=name, provider=provider, config=config)
@@ -477,6 +477,7 @@ class DocumentService:
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
 
+
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
         if 'original_document_id' in document_data and document_data["original_document_id"]:
@@ -626,6 +627,9 @@ class DocumentService:
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
             raise ValueError("Document is not available")
+        # update document name
+        if 'name' in document_data and document_data['name']:
+            document.name = document_data['name']
         # save process rule
         if 'process_rule' in document_data and document_data['process_rule']:
             process_rule = document_data["process_rule"]
@@ -767,7 +771,7 @@ class DocumentService:
         return dataset, documents, batch
 
     @classmethod
-    def document_create_args_validate(cls, args: dict):
+    def  document_create_args_validate(cls, args: dict):
         if 'original_document_id' not in args or not args['original_document_id']:
             DocumentService.data_source_args_validate(args)
             DocumentService.process_rule_args_validate(args)
@@ -1014,6 +1018,66 @@ class SegmentService:
         segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
         return segment
 
+    @classmethod
+    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
+        embedding_model = None
+        if dataset.indexing_technique == 'high_quality':
+            embedding_model = ModelFactory.get_embedding_model(
+                tenant_id=dataset.tenant_id,
+                model_provider_name=dataset.embedding_model_provider,
+                model_name=dataset.embedding_model
+            )
+        max_position = db.session.query(func.max(DocumentSegment.position)).filter(
+            DocumentSegment.document_id == document.id
+        ).scalar()
+        pre_segment_data_list = []
+        segment_data_list = []
+        for segment_item in segments:
+            content = segment_item['content']
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
+            tokens = 0
+            if dataset.indexing_technique == 'high_quality' and embedding_model:
+                # calc embedding use tokens
+                tokens = embedding_model.get_num_tokens(content)
+            segment_document = DocumentSegment(
+                tenant_id=current_user.current_tenant_id,
+                dataset_id=document.dataset_id,
+                document_id=document.id,
+                index_node_id=doc_id,
+                index_node_hash=segment_hash,
+                position=max_position + 1 if max_position else 1,
+                content=content,
+                word_count=len(content),
+                tokens=tokens,
+                status='completed',
+                indexing_at=datetime.datetime.utcnow(),
+                completed_at=datetime.datetime.utcnow(),
+                created_by=current_user.id
+            )
+            if document.doc_form == 'qa_model':
+                segment_document.answer = segment_item['answer']
+            db.session.add(segment_document)
+            segment_data_list.append(segment_document)
+            pre_segment_data = {
+                'segment': segment_document,
+                'keywords': segment_item['keywords']
+            }
+            pre_segment_data_list.append(pre_segment_data)
+
+        try:
+            # save vector index
+            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)
+        except Exception as e:
+            logging.exception("create segment index failed")
+            for segment_document in segment_data_list:
+                segment_document.enabled = False
+                segment_document.disabled_at = datetime.datetime.utcnow()
+                segment_document.status = 'error'
+                segment_document.error = str(e)
+        db.session.commit()
+        return segment_data_list
+
     @classmethod
     def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):
         indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

+ 1 - 1
api/services/errors/__init__.py

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
# Each name listed here is an error submodule; the star-import below pulls
# them all into the ``services.errors`` package namespace.
__all__ = [
    'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
    'app', 'completion', 'audio', 'file'
]

from . import *

+ 8 - 0
api/services/errors/file.py

@@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError
 
 class FileNotExistsError(BaseServiceError):
     pass
+
+
class FileTooLargeError(BaseServiceError):
    """Uploaded file exceeds the configured size limit."""
    description = "{message}"


class UnsupportedFileTypeError(BaseServiceError):
    """Uploaded file extension is not in the allowed list."""
    pass

+ 123 - 0
api/services/file_service.py

@@ -0,0 +1,123 @@
+import datetime
+import hashlib
+import time
+import uuid
+
+from cachetools import TTLCache
+from flask import request, current_app
+from flask_login import current_user
+from werkzeug.datastructures import FileStorage
+from werkzeug.exceptions import NotFound
+
+from core.data_loader.file_extractor import FileExtractor
+from extensions.ext_storage import storage
+from extensions.ext_database import db
+from models.model import UploadFile
+from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
+
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000
# BUG FIX: this cache was created with ``maxsize=None``; cachetools caches
# need an integer maxsize (inserts compare sizes against it), so it could
# never have been populated.  Use a real bound.
cache = TTLCache(maxsize=256, ttl=30)


class FileService:
    """Upload handling, text import and preview extraction for dataset files."""

    @staticmethod
    def upload_file(file: FileStorage) -> UploadFile:
        """Store an uploaded file and create its UploadFile DB record.

        Raises:
            FileTooLargeError: size exceeds ``UPLOAD_FILE_SIZE_LIMIT`` (MB).
            UnsupportedFileTypeError: extension not in ``ALLOWED_EXTENSIONS``.
        """
        # read file content
        file_content = file.read()
        # get file size
        file_size = len(file_content)

        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
        if file_size > file_size_limit:
            message = f'File size exceeded. {file_size} > {file_size_limit}'
            raise FileTooLargeError(message)

        extension = file.filename.split('.')[-1]
        if extension.lower() not in ALLOWED_EXTENSIONS:
            raise UnsupportedFileTypeError()

        # use a uuid as the stored object name so uploads cannot collide
        file_uuid = str(uuid.uuid4())
        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension

        # save file to storage
        storage.save(file_key, file_content)

        # save file to db
        config = current_app.config
        upload_file = UploadFile(
            tenant_id=current_user.current_tenant_id,
            storage_type=config['STORAGE_TYPE'],
            key=file_key,
            name=file.filename,
            size=file_size,
            extension=extension,
            mime_type=file.mimetype,
            created_by=current_user.id,
            created_at=datetime.datetime.utcnow(),
            used=False,
            hash=hashlib.sha3_256(file_content).hexdigest()
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file

    @staticmethod
    def upload_text(text: str, text_name: str) -> UploadFile:
        """Persist raw text as a .txt upload, recorded as already used."""
        file_uuid = str(uuid.uuid4())
        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'

        # save file to storage
        storage.save(file_key, text.encode('utf-8'))

        # save file to db
        config = current_app.config
        upload_file = UploadFile(
            tenant_id=current_user.current_tenant_id,
            storage_type=config['STORAGE_TYPE'],
            key=file_key,
            name=text_name + '.txt',
            # NOTE(review): character count, while the stored object is UTF-8
            # bytes -- the two diverge for non-ASCII text.  Kept as-is.
            size=len(text),
            extension='txt',
            mime_type='text/plain',
            created_by=current_user.id,
            created_at=datetime.datetime.utcnow(),
            used=True,
            used_by=current_user.id,
            used_at=datetime.datetime.utcnow()
        )

        db.session.add(upload_file)
        db.session.commit()

        return upload_file

    @staticmethod
    def get_file_preview(file_id: str) -> str:
        """Return up to ``PREVIEW_WORDS_LIMIT`` characters of extracted text.

        Results are memoized per (file_id, request path) for ``cache.ttl``
        seconds.  Raises NotFound / UnsupportedFileTypeError.
        """
        # get file storage key
        key = file_id + request.path
        cached_response = cache.get(key)
        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
            return cached_response['response']

        upload_file = db.session.query(UploadFile) \
            .filter(UploadFile.id == file_id) \
            .first()

        if not upload_file:
            raise NotFound("File not found")

        # extract text from file
        extension = upload_file.extension
        if extension.lower() not in ALLOWED_EXTENSIONS:
            raise UnsupportedFileTypeError()

        text = FileExtractor.load(upload_file, return_text=True)
        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''

        # BUG FIX: the cache was read above but never written, so every call
        # re-fetched and re-parsed the file.  Populate it here.
        cache[key] = {'timestamp': time.time(), 'response': text}

        return text

+ 26 - 0
api/services/vector_service.py

@@ -35,6 +35,32 @@ class VectorService:
             else:
                 index.add_texts([document])
 
+    @classmethod
+    def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset):
+        documents = []
+        for pre_segment_data in pre_segment_data_list:
+            segment = pre_segment_data['segment']
+            document = Document(
+                page_content=segment.content,
+                metadata={
+                    "doc_id": segment.index_node_id,
+                    "doc_hash": segment.index_node_hash,
+                    "document_id": segment.document_id,
+                    "dataset_id": segment.dataset_id,
+                }
+            )
+            documents.append(document)
+
+        # save vector index
+        index = IndexBuilder.get_index(dataset, 'high_quality')
+        if index:
+            index.add_texts(documents, duplicate_check=True)
+
+        # save keyword index
+        keyword_index = IndexBuilder.get_index(dataset, 'economy')
+        if keyword_index:
+            keyword_index.multi_create_segment_keywords(pre_segment_data_list)
+
     @classmethod
     def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset):
         # update segment index task