提交 bccfa485 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2980 Prevent empty column names

Merge pull request !2980 from nhussain/empty_column_b
...@@ -123,25 +123,39 @@ def check_valid_detype(type_): ...@@ -123,25 +123,39 @@ def check_valid_detype(type_):
def check_columns(columns, name):
    """
    Validate a column name or a list of column names.

    Args:
        columns (Union[str, list]): a single column name or a list of column names.
        name (str): name of the argument being validated, used in error messages.

    Raises:
        ValueError: if columns is an empty str, an empty list, or a list
            containing an empty str.
        Exception: whatever type_check / type_check_list raise when columns, or
            one of its elements, has the wrong type (the tests accompanying this
            change expect a TypeError for a non-str element).
    """
    type_check(columns, (list, str), name)

    if isinstance(columns, str):
        # A bare empty string is as invalid a column name as [""].
        if not columns:
            raise ValueError("{0} should not be empty".format(name))
        return

    if not columns:
        raise ValueError("{0} should not be empty".format(name))

    # Check element types before emptiness so that falsy non-str values
    # (0, None, [], ...) are reported as a type problem, not as "empty".
    col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
    type_check_list(columns, (str,), col_names)

    for i, column_name in enumerate(columns):
        if not column_name:
            raise ValueError("{0}[{1}] should not be empty".format(name, i))
def parse_user_args(method, *args, **kwargs): def parse_user_args(method, *args, **kwargs):
""" """
Parse user arguments in a function Parse user arguments in a function.
Args: Args:
method (method): a callable function method (method): a callable function.
*args: user passed args *args: user passed args.
**kwargs: user passed kwargs **kwargs: user passed kwargs.
Returns: Returns:
user_filled_args (list): values of what the user passed in for the arguments, user_filled_args (list): values of what the user passed in for the arguments.
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed. ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
""" """
sig = inspect.signature(method) sig = inspect.signature(method)
...@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs): ...@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
def type_check_list(args, types, arg_names): def type_check_list(args, types, arg_names):
""" """
Check the type of each parameter in the list Check the type of each parameter in the list.
Args: Args:
args (list, tuple): a list or tuple of any variable args (list, tuple): a list or tuple of any variable.
types (tuple): tuple of all valid types for arg types (tuple): tuple of all valid types for arg.
arg_names (list, tuple of str): the names of args arg_names (list, tuple of str): the names of args.
Returns: Returns:
Exception: when the type is not correct, otherwise nothing Exception: when the type is not correct, otherwise nothing.
""" """
type_check(args, (list, tuple,), arg_names) type_check(args, (list, tuple,), arg_names)
if len(args) != len(arg_names): if len(args) != len(arg_names):
...@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names): ...@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
def type_check(arg, types, arg_name): def type_check(arg, types, arg_name):
""" """
Check the type of the parameter Check the type of the parameter.
Args: Args:
arg : any variable arg : any variable.
types (tuple): tuple of all valid types for arg types (tuple): tuple of all valid types for arg.
arg_name (str): the name of arg arg_name (str): the name of arg.
Returns: Returns:
Exception: when the type is not correct, otherwise nothing Exception: when the type is not correct, otherwise nothing.
""" """
# handle special case of booleans being a subclass of ints # handle special case of booleans being a subclass of ints
print_value = '\"\"' if repr(arg) == repr('') else arg print_value = '\"\"' if repr(arg) == repr('') else arg
...@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name): ...@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
def check_filename(path): def check_filename(path):
""" """
check the filename in the path check the filename in the path.
Args: Args:
path (str): the path path (str): the path.
Returns: Returns:
Exception: when error Exception: when error.
""" """
if not isinstance(path, str): if not isinstance(path, str):
raise TypeError("path: {} is not string".format(path)) raise TypeError("path: {} is not string".format(path))
...@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict): ...@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
""" """
Check for valid shuffle, sampler, num_shards, and shard_id inputs. Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args: Args:
param_dict (dict): param_dict param_dict (dict): param_dict.
Returns: Returns:
Exception: ValueError or RuntimeError if error Exception: ValueError or RuntimeError if error.
""" """
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
...@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict): ...@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
def check_padding_options(param_dict): def check_padding_options(param_dict):
""" """
Check for valid padded_sample and num_padded of padded samples Check for valid padded_sample and num_padded of padded samples.
Args: Args:
param_dict (dict): param_dict param_dict (dict): param_dict.
Returns: Returns:
Exception: ValueError or RuntimeError if error Exception: ValueError or RuntimeError if error.
""" """
columns_list = param_dict.get('columns_list') columns_list = param_dict.get('columns_list')
...@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name): ...@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
Check if the input parameter is list or numpy.ndarray. Check if the input parameter is list or numpy.ndarray.
Args: Args:
param (list, nd.ndarray): param param (list, nd.ndarray): param.
param_name (str): param_name param_name (str): param_name.
Returns: Returns:
Exception: TypeError if error Exception: TypeError if error.
""" """
type_check(param, (list, np.ndarray), param_name) type_check(param, (list, np.ndarray), param_name)
......
...@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method): ...@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list) type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)
# check column_names: must be list of string. # check column_names: must be list of string.
if not column_names: check_columns(column_names, "column_names")
raise ValueError("column_names cannot be empty")
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")
if element_length_function is None and len(column_names) != 1: if element_length_function is None and len(column_names) != 1:
raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
......
...@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input(): ...@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value) assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(info.value)
with pytest.raises(ValueError) as info: with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import pytest
import mindspore.dataset as de import mindspore.dataset as de
from mindspore import log as logger from mindspore import log as logger
import mindspore.dataset.transforms.vision.c_transforms as vision import mindspore.dataset.transforms.vision.c_transforms as vision
...@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler(): ...@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
def test_numpy_slices_sequential_sampler(): def test_numpy_slices_sequential_sampler():
logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.") logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")
np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
...@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler(): ...@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
assert np.equal(data[0], np_data[i % 8]).all() assert np.equal(data[0], np_data[i % 8]).all()
def test_numpy_slices_invalid_column_names_type():
    """A non-str entry in column_names must be rejected with a TypeError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected_msg = "Argument column_names[0] with value 1 is not of type (<class 'str'>,)."

    with pytest.raises(TypeError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[1], shuffle=False)

    # Checked after the with-block: the exception info is only populated on exit.
    assert expected_msg in str(exc_info.value)
def test_numpy_slices_invalid_column_names_string():
    """An empty-string entry in column_names must be rejected with a ValueError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected_msg = "column_names[0] should not be empty"

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[""], shuffle=False)

    # Checked after the with-block: the exception info is only populated on exit.
    assert expected_msg in str(exc_info.value)
def test_numpy_slices_invalid_empty_column_names():
    """An empty column_names list must be rejected with a ValueError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected_msg = "column_names should not be empty"

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[], shuffle=False)

    # Checked after the with-block: the exception info is only populated on exit.
    assert expected_msg in str(exc_info.value)
if __name__ == "__main__": if __name__ == "__main__":
test_numpy_slices_list_1() test_numpy_slices_list_1()
test_numpy_slices_list_2() test_numpy_slices_list_2()
...@@ -197,3 +224,6 @@ if __name__ == "__main__": ...@@ -197,3 +224,6 @@ if __name__ == "__main__":
test_numpy_slices_num_samplers() test_numpy_slices_num_samplers()
test_numpy_slices_distributed_sampler() test_numpy_slices_distributed_sampler()
test_numpy_slices_sequential_sampler() test_numpy_slices_sequential_sampler()
test_numpy_slices_invalid_column_names_type()
test_numpy_slices_invalid_column_names_string()
test_numpy_slices_invalid_empty_column_names()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册