Browse Source

fix: saving error in empty dataset (#2098)

crazywoola 1 year ago
parent
commit
ec1659cba0

+ 9 - 5
api/controllers/console/datasets/datasets.py

@@ -19,7 +19,7 @@ from flask import current_app, request
 from flask_login import current_user
 from flask_restful import Resource, marshal, marshal_with, reqparse
 from libs.login import login_required
-from models.dataset import Document, DocumentSegment
+from models.dataset import Dataset, Document, DocumentSegment
 from models.model import ApiToken, UploadFile
 from services.dataset_service import DatasetService, DocumentService
 from werkzeug.exceptions import Forbidden, NotFound
@@ -97,7 +97,8 @@ class DatasetListApi(Resource):
                             help='type is required. Name must be between 1 to 40 characters.',
                             type=_validate_name)
         parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=('high_quality', 'economy'),
+                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                            nullable=True,
                             help='Invalid indexing technique.')
         args = parser.parse_args()
 
@@ -177,8 +178,9 @@ class DatasetApi(Resource):
                             location='json', store_missing=False,
                             type=_validate_description_length)
         parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=('high_quality', 'economy'),
-                            help='Invalid indexing technique.')
+                    choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                    nullable=True,
+                    help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -256,7 +258,9 @@ class DatasetIndexingEstimateApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
         parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
-        parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
+        parser.add_argument('indexing_technique', type=str, required=True, 
+                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
+                            nullable=True, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,

+ 2 - 1
api/controllers/service_api/dataset/dataset.py

@@ -1,3 +1,4 @@
+from models.dataset import Dataset
 import services.dataset_service
 from controllers.service_api import api
 from controllers.service_api.dataset.error import DatasetNameDuplicateError
@@ -68,7 +69,7 @@ class DatasetApi(DatasetApiResource):
                             help='type is required. Name must be between 1 to 40 characters.',
                             type=_validate_name)
         parser.add_argument('indexing_technique', type=str, location='json',
-                            choices=('high_quality', 'economy'),
+                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                             help='Invalid indexing technique.')
         args = parser.parse_args()
 

+ 1 - 1
api/models/dataset.py

@@ -17,7 +17,7 @@ class Dataset(db.Model):
         db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
     )
 
-    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
+    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
 
     id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
     tenant_id = db.Column(UUID, nullable=False)