diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py index f7b334635930606c5e9ccf79eb1946731c10f6f1..d0c17875b7e3c67c7ccc148c7bbe23bf1d11f769 100644 --- a/mindspore/dataset/core/validator_helpers.py +++ b/mindspore/dataset/core/validator_helpers.py @@ -95,7 +95,7 @@ def check_uint32(value, arg_name=""): def check_pos_int32(value, arg_name=""): type_check(value, (int,), arg_name) - check_value(value, [POS_INT_MIN, INT32_MAX]) + check_value(value, [POS_INT_MIN, INT32_MAX], arg_name) def check_uint64(value, arg_name=""): @@ -143,6 +143,8 @@ def check_columns(columns, name): col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))] type_check_list(columns, (str,), col_names) + if len(set(columns)) != len(columns): + raise ValueError("Every column name should not be same with others in column_names.") def parse_user_args(method, *args, **kwargs): diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ae0dc6789e1d63db1ff344396081580bdf3b0ec1..cb6376ebd5504dafeda5b6c397452ddc234cf1cf 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_split, check_bucket_batch_by_length, check_cluedataset + check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -939,6 +939,7 @@ class Dataset: raise TypeError("apply_func must return a dataset.") return dataset + @check_positive_int32 def device_que(self, prefetch_size=None): """ Return a transferredDataset that transfer data through device. @@ -956,6 +957,7 @@ class Dataset: """ return self.to_device() + @check_positive_int32 def to_device(self, num_batch=None): """ Transfer data through CPU, GPU or Ascend devices. @@ -973,7 +975,7 @@ class Dataset: Raises: TypeError: If device_type is empty. ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'. - ValueError: If num_batch is None or 0 or larger than int_max. + ValueError: If num_batch is negative or larger than int_max. RuntimeError: If dataset is unknown. RuntimeError: If distribution file path is given but failed to read. """ diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index ab7cc6ac54387060c1cdcde3aaa801beace06c4b..8f127e03136ceb11c5f14f7b1a7aaf51ddb1faf8 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -25,7 +25,7 @@ from mindspore._c_expression import typing from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ - check_columns, check_positive + check_columns, check_positive, check_pos_int32 from . import datasets from . import samplers @@ -593,6 +593,25 @@ def check_take(method): return new_method +def check_positive_int32(method): + """check whether the input argument is positive and int, only works for functions with one input.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [count], param_dict = parse_user_args(method, *args, **kwargs) + para_name = None + for key in list(param_dict.keys()): + if key not in ['self', 'cls']: + para_name = key + # Need to get default value of param + if count is not None: + check_pos_int32(count, para_name) + + return method(self, *args, **kwargs) + + return new_method + + def check_zip(method): """check the input arguments of zip."""