diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py
index 7a93fcf174ba31eede63c47aae4f33f7a010df13..f7b334635930606c5e9ccf79eb1946731c10f6f1 100644
--- a/mindspore/dataset/core/validator_helpers.py
+++ b/mindspore/dataset/core/validator_helpers.py
@@ -123,25 +123,39 @@ def check_valid_detype(type_):
 
 
 def check_columns(columns, name):
+    """
+    Validate strings in column_names.
+
+    Args:
+        columns (list): list of column_names.
+        name (str): name of columns.
+
+    Returns:
+        Exception: when the value is not correct, otherwise nothing.
+    """
     type_check(columns, (list, str), name)
     if isinstance(columns, list):
         if not columns:
-            raise ValueError("Column names should not be empty")
-        col_names = ["col_{0}".format(i) for i in range(len(columns))]
+            raise ValueError("{0} should not be empty".format(name))
+        for i, column_name in enumerate(columns):
+            if not column_name:
+                raise ValueError("{0}[{1}] should not be empty".format(name, i))
+
+        col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
         type_check_list(columns, (str,), col_names)
 
 
 def parse_user_args(method, *args, **kwargs):
     """
-    Parse user arguments in a function
+    Parse user arguments in a function.
 
     Args:
-        method (method): a callable function
-        *args: user passed args
-        **kwargs: user passed kwargs
+        method (method): a callable function.
+        *args: user passed args.
+        **kwargs: user passed kwargs.
 
     Returns:
-        user_filled_args (list): values of what the user passed in for the arguments,
+        user_filled_args (list): values of what the user passed in for the arguments.
         ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
     """
     sig = inspect.signature(method)
@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
 
 def type_check_list(args, types, arg_names):
     """
-    Check the type of each parameter in the list
+    Check the type of each parameter in the list.
 
     Args:
-        args (list, tuple): a list or tuple of any variable
-        types (tuple): tuple of all valid types for arg
-        arg_names (list, tuple of str): the names of args
+        args (list, tuple): a list or tuple of any variable.
+        types (tuple): tuple of all valid types for arg.
+        arg_names (list, tuple of str): the names of args.
 
     Returns:
-        Exception: when the type is not correct, otherwise nothing
+        Exception: when the type is not correct, otherwise nothing.
     """
     type_check(args, (list, tuple,), arg_names)
     if len(args) != len(arg_names):
@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
 
 def type_check(arg, types, arg_name):
     """
-    Check the type of the parameter
+    Check the type of the parameter.
 
     Args:
-        arg : any variable
-        types (tuple): tuple of all valid types for arg
-        arg_name (str): the name of arg
+        arg : any variable.
+        types (tuple): tuple of all valid types for arg.
+        arg_name (str): the name of arg.
 
     Returns:
-        Exception: when the type is not correct, otherwise nothing
+        Exception: when the type is not correct, otherwise nothing.
     """
     # handle special case of booleans being a subclass of ints
     print_value = '\"\"' if repr(arg) == repr('') else arg
@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
 
 def check_filename(path):
     """
-    check the filename in the path
+    check the filename in the path.
 
     Args:
-        path (str): the path
+        path (str): the path.
 
     Returns:
-        Exception: when error
+        Exception: when error.
     """
     if not isinstance(path, str):
         raise TypeError("path: {} is not string".format(path))
@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
     """
     Check for valid shuffle, sampler, num_shards, and shard_id inputs.
     Args:
-        param_dict (dict): param_dict
+        param_dict (dict): param_dict.
 
     Returns:
-        Exception: ValueError or RuntimeError if error
+        Exception: ValueError or RuntimeError if error.
     """
     shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
     num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
 
 def check_padding_options(param_dict):
     """
-    Check for valid padded_sample and num_padded of padded samples
+    Check for valid padded_sample and num_padded of padded samples.
 
     Args:
-        param_dict (dict): param_dict
+        param_dict (dict): param_dict.
 
     Returns:
-        Exception: ValueError or RuntimeError if error
+        Exception: ValueError or RuntimeError if error.
     """
 
     columns_list = param_dict.get('columns_list')
@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
     Check if the input parameter is list or numpy.ndarray.
 
     Args:
-        param (list, nd.ndarray): param
-        param_name (str): param_name
+        param (list, nd.ndarray): param.
+        param_name (str): param_name.
 
     Returns:
-        Exception: TypeError if error
+        Exception: TypeError if error.
     """
 
     type_check(param, (list, np.ndarray), param_name)
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index 7edf381b2c6e00e5fbedb6b0f6c02c443e38db41..ab7cc6ac54387060c1cdcde3aaa801beace06c4b 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
         type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)
 
         # check column_names: must be list of string.
-        if not column_names:
-            raise ValueError("column_names cannot be empty")
-
-        all_string = all(isinstance(item, str) for item in column_names)
-        if not all_string:
-            raise TypeError("column_names should be a list of str.")
+        check_columns(column_names, "column_names")
 
         if element_length_function is None and len(column_names) != 1:
             raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py
index a30b5827cb53bc8a7113adb23bc27bf1dc758fde..5da7b1636da3090a79957e890e74d364b7d86573 100644
--- a/tests/ut/python/dataset/test_bucket_batch_by_length.py
+++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py
@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
 
     with pytest.raises(TypeError) as info:
         _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
-    assert "column_names should be a list of str" in str(info.value)
+    assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(info.value)
 
     with pytest.raises(ValueError) as info:
         _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes)
diff --git a/tests/ut/python/dataset/test_dataset_numpy_slices.py b/tests/ut/python/dataset/test_dataset_numpy_slices.py
index 4cd4e26a337bba4513335f18fede209b358c39da..fe773b0328f5a910375f425f53cbfe9e7225e8f1 100644
--- a/tests/ut/python/dataset/test_dataset_numpy_slices.py
+++ b/tests/ut/python/dataset/test_dataset_numpy_slices.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import numpy as np
+import pytest
 import mindspore.dataset as de
 from mindspore import log as logger
 import mindspore.dataset.transforms.vision.c_transforms as vision
@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
 
 
 def test_numpy_slices_sequential_sampler():
-
     logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")
 
     np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
         assert np.equal(data[0], np_data[i % 8]).all()
 
 
+def test_numpy_slices_invalid_column_names_type():
+    logger.info("Test incorrect column_names input")
+    np_data = [1, 2, 3]
+
+    with pytest.raises(TypeError) as err:
+        de.NumpySlicesDataset(np_data, column_names=[1], shuffle=False)
+    assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(err.value)
+
+
+def test_numpy_slices_invalid_column_names_string():
+    logger.info("Test incorrect column_names input")
+    np_data = [1, 2, 3]
+
+    with pytest.raises(ValueError) as err:
+        de.NumpySlicesDataset(np_data, column_names=[""], shuffle=False)
+    assert "column_names[0] should not be empty" in str(err.value)
+
+
+def test_numpy_slices_invalid_empty_column_names():
+    logger.info("Test incorrect column_names input")
+    np_data = [1, 2, 3]
+
+    with pytest.raises(ValueError) as err:
+        de.NumpySlicesDataset(np_data, column_names=[], shuffle=False)
+    assert "column_names should not be empty" in str(err.value)
+
+
 if __name__ == "__main__":
     test_numpy_slices_list_1()
     test_numpy_slices_list_2()
@@ -197,3 +224,6 @@ if __name__ == "__main__":
     test_numpy_slices_num_samplers()
     test_numpy_slices_distributed_sampler()
     test_numpy_slices_sequential_sampler()
+    test_numpy_slices_invalid_column_names_type()
+    test_numpy_slices_invalid_column_names_string()
+    test_numpy_slices_invalid_empty_column_names()