@@ -21,12 +21,11 @@ from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
-from models.account import Account, TenantAccountRole
+from models.account import Account
 from models.dataset import (
     AppDatasetJoin,
     Dataset,
     DatasetCollectionBinding,
-    DatasetPermission,
     DatasetProcessRule,
     DatasetQuery,
     Document,
@@ -57,38 +56,22 @@ class DatasetService:

     @staticmethod
     def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
-        query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id)
-
         if user:
-            if user.current_role == TenantAccountRole.DATASET_OPERATOR:
-                dataset_permission = DatasetPermission.query.filter_by(account_id=user.id).all()
-                if dataset_permission:
-                    dataset_ids = [dp.dataset_id for dp in dataset_permission]
-                    query = query.filter(Dataset.id.in_(dataset_ids))
-                else:
-                    query = query.filter(db.false())
-            else:
-                permission_filter = db.or_(
-                    Dataset.created_by == user.id,
-                    Dataset.permission == 'all_team_members',
-                    Dataset.permission == 'partial_members',
-                    Dataset.permission == 'only_me'
-                )
-                query = query.filter(permission_filter)
+            permission_filter = db.or_(Dataset.created_by == user.id,
+                                       Dataset.permission == 'all_team_members')
         else:
             permission_filter = Dataset.permission == 'all_team_members'
-        query = query.filter(permission_filter)
-
+        query = Dataset.query.filter(
+            db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
+            .order_by(Dataset.created_at.desc())
         if search:
-            query = query.filter(Dataset.name.ilike(f'%{search}%'))
-
+            query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
         if tag_ids:
             target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
             if target_ids:
-                query = query.filter(Dataset.id.in_(target_ids))
+                query = query.filter(db.and_(Dataset.id.in_(target_ids)))
             else:
                 return [], 0
-
         datasets = query.paginate(
             page=page,
             per_page=per_page,
@@ -96,12 +79,6 @@ class DatasetService:
             error_out=False
         )

-        # check datasets permission,
-        if user and user.current_role != TenantAccountRole.DATASET_OPERATOR:
-            datasets.items, datasets.total = DatasetService.filter_datasets_by_permission(
-                user, datasets
-            )
-
         return datasets.items, datasets.total

     @staticmethod
@@ -125,12 +102,9 @@ class DatasetService:

     @staticmethod
     def get_datasets_by_ids(ids, tenant_id):
-        datasets = Dataset.query.filter(
-            Dataset.id.in_(ids),
-            Dataset.tenant_id == tenant_id
-        ).paginate(
-            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
-        )
+        datasets = Dataset.query.filter(Dataset.id.in_(ids),
+                                        Dataset.tenant_id == tenant_id).paginate(
+            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
         return datasets.items, datasets.total

     @staticmethod
@@ -138,8 +112,7 @@ class DatasetService:
         # check if dataset name already exists
         if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
             raise DatasetNameDuplicateError(
-                f'Dataset with name {name} already exists.'
-            )
+                f'Dataset with name {name} already exists.')
         embedding_model = None
         if indexing_technique == 'high_quality':
             model_manager = ModelManager()
@@ -178,17 +151,13 @@ class DatasetService:
            except LLMBadRequestError:
                raise ValueError(
                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider."
-                )
+                    "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
-                raise ValueError(
-                    f"The dataset in unavailable, due to: "
-                    f"{ex.description}"
-                )
+                raise ValueError(f"The dataset in unavailable, due to: "
+                                 f"{ex.description}")

     @staticmethod
     def update_dataset(dataset_id, data, user):
-        data.pop('partial_member_list', None)
         filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
         dataset = DatasetService.get_dataset(dataset_id)
         DatasetService.check_dataset_permission(dataset, user)
@@ -221,13 +190,12 @@ class DatasetService:
            except LLMBadRequestError:
                raise ValueError(
                    "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider."
-                )
+                    "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ValueError(ex.description)
        else:
            if data['embedding_model_provider'] != dataset.embedding_model_provider or \
-                data['embedding_model'] != dataset.embedding_model:
+                    data['embedding_model'] != dataset.embedding_model:
                action = 'update'
                try:
                    model_manager = ModelManager()
@@ -247,8 +215,7 @@ class DatasetService:
                except LLMBadRequestError:
                    raise ValueError(
                        "No Embedding Model available. Please configure a valid provider "
-                        "in the Settings -> Model Provider."
-                    )
+                        "in the Settings -> Model Provider.")
                except ProviderTokenNotInitError as ex:
                    raise ValueError(ex.description)

@@ -292,41 +259,14 @@ class DatasetService:
     def check_dataset_permission(dataset, user):
         if dataset.tenant_id != user.current_tenant_id:
             logging.debug(
-                f'User {user.id} does not have permission to access dataset {dataset.id}'
-            )
+                f'User {user.id} does not have permission to access dataset {dataset.id}')
             raise NoPermissionError(
-                'You do not have permission to access this dataset.'
-            )
+                'You do not have permission to access this dataset.')
         if dataset.permission == 'only_me' and dataset.created_by != user.id:
             logging.debug(
-                f'User {user.id} does not have permission to access dataset {dataset.id}'
-            )
+                f'User {user.id} does not have permission to access dataset {dataset.id}')
             raise NoPermissionError(
-                'You do not have permission to access this dataset.'
-            )
-        if dataset.permission == 'partial_members':
-            user_permission = DatasetPermission.query.filter_by(
-                dataset_id=dataset.id, account_id=user.id
-            ).first()
-            if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id:
-                logging.debug(
-                    f'User {user.id} does not have permission to access dataset {dataset.id}'
-                )
-                raise NoPermissionError(
-                    'You do not have permission to access this dataset.'
-                )
-
-    @staticmethod
-    def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None):
-        if dataset.permission == 'only_me':
-            if dataset.created_by != user.id:
-                raise NoPermissionError('You do not have permission to access this dataset.')
-
-        elif dataset.permission == 'partial_members':
-            if not any(
-                dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all()
-            ):
-                raise NoPermissionError('You do not have permission to access this dataset.')
+                'You do not have permission to access this dataset.')

     @staticmethod
     def get_dataset_queries(dataset_id: str, page: int, per_page: int):
@@ -342,22 +282,6 @@ class DatasetService:
         return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \
             .order_by(db.desc(AppDatasetJoin.created_at)).all()

-    @staticmethod
-    def filter_datasets_by_permission(user, datasets):
-        dataset_permission = DatasetPermission.query.filter_by(account_id=user.id).all()
-        permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else set()
-
-        filtered_datasets = [
-            dataset for dataset in datasets if
-            (dataset.permission == 'all_team_members') or
-            (dataset.permission == 'only_me' and dataset.created_by == user.id) or
-            (dataset.id in permitted_dataset_ids)
-        ]
-
-        filtered_count = len(filtered_datasets)
-
-        return filtered_datasets, filtered_count
-

 class DocumentService:
     DEFAULT_RULES = {
@@ -623,7 +547,6 @@ class DocumentService:
         redis_client.setex(sync_indexing_cache_key, 600, 1)

         sync_website_document_indexing_task.delay(dataset_id, document.id)
-
     @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
@@ -633,11 +556,9 @@ class DocumentService:
             return 1

     @staticmethod
-    def save_document_with_dataset_id(
-            dataset: Dataset, document_data: dict,
-            account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
-            created_from: str = 'web'
-    ):
+    def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
+                                      account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
+                                      created_from: str = 'web'):

         # check document limit
         features = FeatureService.get_features(current_user.current_tenant_id)
@@ -667,7 +588,7 @@ class DocumentService:

         if not dataset.indexing_technique:
             if 'indexing_technique' not in document_data \
-                or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
+                    or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
                 raise ValueError("Indexing technique is required")

             dataset.indexing_technique = document_data["indexing_technique"]
@@ -697,8 +618,7 @@ class DocumentService:
                 }

                 dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(
-                    'retrieval_model'
-                ) else default_retrieval_model
+                    'retrieval_model') else default_retrieval_model

         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -766,14 +686,12 @@ class DocumentService:
                         documents.append(document)
                         duplicate_document_ids.append(document.id)
                         continue
-                document = DocumentService.build_document(
-                    dataset, dataset_process_rule.id,
-                    document_data["data_source"]["type"],
-                    document_data["doc_form"],
-                    document_data["doc_language"],
-                    data_source_info, created_from, position,
-                    account, file_name, batch
-                )
+                document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                          document_data["data_source"]["type"],
+                                                          document_data["doc_form"],
+                                                          document_data["doc_language"],
+                                                          data_source_info, created_from, position,
+                                                          account, file_name, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -814,14 +732,12 @@ class DocumentService:
                         "notion_page_icon": page['page_icon'],
                         "type": page['type']
                     }
-                    document = DocumentService.build_document(
-                        dataset, dataset_process_rule.id,
-                        document_data["data_source"]["type"],
-                        document_data["doc_form"],
-                        document_data["doc_language"],
-                        data_source_info, created_from, position,
-                        account, page['page_name'], batch
-                    )
+                    document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                              document_data["data_source"]["type"],
+                                                              document_data["doc_form"],
+                                                              document_data["doc_language"],
+                                                              data_source_info, created_from, position,
+                                                              account, page['page_name'], batch)
                     db.session.add(document)
                     db.session.flush()
                     document_ids.append(document.id)
@@ -843,14 +759,12 @@ class DocumentService:
                     'only_main_content': website_info.get('only_main_content', False),
                     'mode': 'crawl',
                 }
-                document = DocumentService.build_document(
-                    dataset, dataset_process_rule.id,
-                    document_data["data_source"]["type"],
-                    document_data["doc_form"],
-                    document_data["doc_language"],
-                    data_source_info, created_from, position,
-                    account, url, batch
-                )
+                document = DocumentService.build_document(dataset, dataset_process_rule.id,
+                                                          document_data["data_source"]["type"],
+                                                          document_data["doc_form"],
+                                                          document_data["doc_language"],
+                                                          data_source_info, created_from, position,
+                                                          account, url, batch)
                 db.session.add(document)
                 db.session.flush()
                 document_ids.append(document.id)
@@ -871,16 +785,13 @@ class DocumentService:
             can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
             if count > can_upload_size:
                 raise ValueError(
-                    f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.'
-                )
+                    f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')

     @staticmethod
-    def build_document(
-            dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
-            document_language: str, data_source_info: dict, created_from: str, position: int,
-            account: Account,
-            name: str, batch: str
-    ):
+    def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
+                       document_language: str, data_source_info: dict, created_from: str, position: int,
+                       account: Account,
+                       name: str, batch: str):
         document = Document(
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
@@ -899,20 +810,16 @@ class DocumentService:

     @staticmethod
     def get_tenant_documents_count():
-        documents_count = Document.query.filter(
-            Document.completed_at.isnot(None),
-            Document.enabled == True,
-            Document.archived == False,
-            Document.tenant_id == current_user.current_tenant_id
-        ).count()
+        documents_count = Document.query.filter(Document.completed_at.isnot(None),
+                                                Document.enabled == True,
+                                                Document.archived == False,
+                                                Document.tenant_id == current_user.current_tenant_id).count()
         return documents_count

     @staticmethod
-    def update_document_with_dataset_id(
-            dataset: Dataset, document_data: dict,
-            account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
-            created_from: str = 'web'
-    ):
+    def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
+                                        account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
+                                        created_from: str = 'web'):
         DatasetService.check_dataset_model_setting(dataset)
         document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
         if document.display_status != 'available':
@@ -1100,7 +1007,7 @@ class DocumentService:
             DocumentService.process_rule_args_validate(args)
         else:
             if ('data_source' not in args and not args['data_source']) \
-                and ('process_rule' not in args and not args['process_rule']):
+                    and ('process_rule' not in args and not args['process_rule']):
                 raise ValueError("Data source or Process rule is required")
             else:
                 if args.get('data_source'):
@@ -1162,7 +1069,7 @@ class DocumentService:
             raise ValueError("Process rule rules is invalid")

         if 'pre_processing_rules' not in args['process_rule']['rules'] \
-            or args['process_rule']['rules']['pre_processing_rules'] is None:
+                or args['process_rule']['rules']['pre_processing_rules'] is None:
             raise ValueError("Process rule pre_processing_rules is required")

         if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
@@ -1187,21 +1094,21 @@ class DocumentService:
         args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())

         if 'segmentation' not in args['process_rule']['rules'] \
-            or args['process_rule']['rules']['segmentation'] is None:
+                or args['process_rule']['rules']['segmentation'] is None:
             raise ValueError("Process rule segmentation is required")

         if not isinstance(args['process_rule']['rules']['segmentation'], dict):
             raise ValueError("Process rule segmentation is invalid")

         if 'separator' not in args['process_rule']['rules']['segmentation'] \
-            or not args['process_rule']['rules']['segmentation']['separator']:
+                or not args['process_rule']['rules']['segmentation']['separator']:
             raise ValueError("Process rule segmentation separator is required")

         if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
             raise ValueError("Process rule segmentation separator is invalid")

         if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
-            or not args['process_rule']['rules']['segmentation']['max_tokens']:
+                or not args['process_rule']['rules']['segmentation']['max_tokens']:
             raise ValueError("Process rule segmentation max_tokens is required")

         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
@@ -1237,7 +1144,7 @@ class DocumentService:
             raise ValueError("Process rule rules is invalid")

         if 'pre_processing_rules' not in args['process_rule']['rules'] \
-            or args['process_rule']['rules']['pre_processing_rules'] is None:
+                or args['process_rule']['rules']['pre_processing_rules'] is None:
             raise ValueError("Process rule pre_processing_rules is required")

         if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):
@@ -1262,21 +1169,21 @@ class DocumentService:
         args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())

         if 'segmentation' not in args['process_rule']['rules'] \
-            or args['process_rule']['rules']['segmentation'] is None:
+                or args['process_rule']['rules']['segmentation'] is None:
             raise ValueError("Process rule segmentation is required")

         if not isinstance(args['process_rule']['rules']['segmentation'], dict):
             raise ValueError("Process rule segmentation is invalid")

         if 'separator' not in args['process_rule']['rules']['segmentation'] \
-            or not args['process_rule']['rules']['segmentation']['separator']:
+                or not args['process_rule']['rules']['segmentation']['separator']:
             raise ValueError("Process rule segmentation separator is required")

         if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):
             raise ValueError("Process rule segmentation separator is invalid")

         if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \
-            or not args['process_rule']['rules']['segmentation']['max_tokens']:
+                or not args['process_rule']['rules']['segmentation']['max_tokens']:
             raise ValueError("Process rule segmentation max_tokens is required")

         if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):
@@ -1530,16 +1437,12 @@ class SegmentService:

 class DatasetCollectionBindingService:
     @classmethod
-    def get_dataset_collection_binding(
-            cls, provider_name: str, model_name: str,
-            collection_type: str = 'dataset'
-    ) -> DatasetCollectionBinding:
+    def get_dataset_collection_binding(cls, provider_name: str, model_name: str,
+                                       collection_type: str = 'dataset') -> DatasetCollectionBinding:
         dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-            filter(
-                DatasetCollectionBinding.provider_name == provider_name,
-                DatasetCollectionBinding.model_name == model_name,
-                DatasetCollectionBinding.type == collection_type
-            ). \
+            filter(DatasetCollectionBinding.provider_name == provider_name,
+                   DatasetCollectionBinding.model_name == model_name,
+                   DatasetCollectionBinding.type == collection_type). \
             order_by(DatasetCollectionBinding.created_at). \
             first()

@@ -1555,76 +1458,12 @@ class DatasetCollectionBindingService:
         return dataset_collection_binding

     @classmethod
-    def get_dataset_collection_binding_by_id_and_type(
-            cls, collection_binding_id: str,
-            collection_type: str = 'dataset'
-    ) -> DatasetCollectionBinding:
+    def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,
+                                                      collection_type: str = 'dataset') -> DatasetCollectionBinding:
         dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-            filter(
-                DatasetCollectionBinding.id == collection_binding_id,
-                DatasetCollectionBinding.type == collection_type
-            ). \
+            filter(DatasetCollectionBinding.id == collection_binding_id,
+                   DatasetCollectionBinding.type == collection_type). \
             order_by(DatasetCollectionBinding.created_at). \
             first()

         return dataset_collection_binding
-
-
-class DatasetPermissionService:
-    @classmethod
-    def get_dataset_partial_member_list(cls, dataset_id):
-        user_list_query = db.session.query(
-            DatasetPermission.account_id,
-        ).filter(
-            DatasetPermission.dataset_id == dataset_id
-        ).all()
-
-        user_list = []
-        for user in user_list_query:
-            user_list.append(user.account_id)
-
-        return user_list
-
-    @classmethod
-    def update_partial_member_list(cls, dataset_id, user_list):
-        try:
-            db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
-            permissions = []
-            for user in user_list:
-                permission = DatasetPermission(
-                    dataset_id=dataset_id,
-                    account_id=user['user_id'],
-                )
-                permissions.append(permission)
-
-            db.session.add_all(permissions)
-            db.session.commit()
-        except Exception as e:
-            db.session.rollback()
-            raise e
-
-    @classmethod
-    def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list):
-        if not user.is_dataset_editor:
-            raise NoPermissionError('User does not have permission to edit this dataset.')
-
-        if user.is_dataset_operator and dataset.permission != requested_permission:
-            raise NoPermissionError('Dataset operators cannot change the dataset permissions.')
-
-        if user.is_dataset_operator and requested_permission == 'partial_members':
-            if not requested_partial_member_list:
-                raise ValueError('Partial member list is required when setting to partial members.')
-
-            local_member_list = cls.get_dataset_partial_member_list(dataset.id)
-            request_member_list = [user['user_id'] for user in requested_partial_member_list]
-            if set(local_member_list) != set(request_member_list):
-                raise ValueError('Dataset operators cannot change the dataset permissions.')
-
-    @classmethod
-    def clear_partial_member_list(cls, dataset_id):
-        try:
-            db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
-            db.session.commit()
-        except Exception as e:
-            db.session.rollback()
-            raise e