提交 bccfa485 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2980 Prevent empty column names

Merge pull request !2980 from nhussain/empty_column_b
......@@ -123,25 +123,39 @@ def check_valid_detype(type_):
def check_columns(columns, name):
    """
    Validate strings in column_names.

    Args:
        columns (list, str): a single column name, or a list of column names.
        name (str): name of the argument being validated; used to build error messages.

    Returns:
        Exception: when the value is not correct, otherwise nothing.

    Raises:
        ValueError: if columns is an empty list, or if any element of the list is
            falsy (for example an empty string).
        TypeError: expected from type_check / type_check_list when columns is neither
            a list nor a str, or when a list element is not a str.
    """
    type_check(columns, (list, str), name)
    # NOTE(review): a bare str (including "") is only type-checked here; confirm
    # whether an empty string passed directly should also be rejected.
    if isinstance(columns, list):
        if not columns:
            raise ValueError("{0} should not be empty".format(name))
        # Emptiness is checked per element before the type check so that the error
        # names the exact offending position, e.g. "column_names[0]".
        for i, column_name in enumerate(columns):
            if not column_name:
                raise ValueError("{0}[{1}] should not be empty".format(name, i))
        # Label each element as "<name>[<index>]" so type errors point at the element.
        col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))]
        type_check_list(columns, (str,), col_names)
def parse_user_args(method, *args, **kwargs):
"""
Parse user arguments in a function
Parse user arguments in a function.
Args:
method (method): a callable function
*args: user passed args
**kwargs: user passed kwargs
method (method): a callable function.
*args: user passed args.
**kwargs: user passed kwargs.
Returns:
user_filled_args (list): values of what the user passed in for the arguments,
user_filled_args (list): values of what the user passed in for the arguments.
ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed.
"""
sig = inspect.signature(method)
......@@ -160,15 +174,15 @@ def parse_user_args(method, *args, **kwargs):
def type_check_list(args, types, arg_names):
"""
Check the type of each parameter in the list
Check the type of each parameter in the list.
Args:
args (list, tuple): a list or tuple of any variable
types (tuple): tuple of all valid types for arg
arg_names (list, tuple of str): the names of args
args (list, tuple): a list or tuple of any variable.
types (tuple): tuple of all valid types for arg.
arg_names (list, tuple of str): the names of args.
Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing.
"""
type_check(args, (list, tuple,), arg_names)
if len(args) != len(arg_names):
......@@ -179,15 +193,15 @@ def type_check_list(args, types, arg_names):
def type_check(arg, types, arg_name):
"""
Check the type of the parameter
Check the type of the parameter.
Args:
arg : any variable
types (tuple): tuple of all valid types for arg
arg_name (str): the name of arg
arg : any variable.
types (tuple): tuple of all valid types for arg.
arg_name (str): the name of arg.
Returns:
Exception: when the type is not correct, otherwise nothing
Exception: when the type is not correct, otherwise nothing.
"""
# handle special case of booleans being a subclass of ints
print_value = '\"\"' if repr(arg) == repr('') else arg
......@@ -201,13 +215,13 @@ def type_check(arg, types, arg_name):
def check_filename(path):
"""
check the filename in the path
check the filename in the path.
Args:
path (str): the path
path (str): the path.
Returns:
Exception: when error
Exception: when error.
"""
if not isinstance(path, str):
raise TypeError("path: {} is not string".format(path))
......@@ -242,10 +256,10 @@ def check_sampler_shuffle_shard_options(param_dict):
"""
Check for valid shuffle, sampler, num_shards, and shard_id inputs.
Args:
param_dict (dict): param_dict
param_dict (dict): param_dict.
Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error.
"""
shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler')
num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id')
......@@ -268,13 +282,13 @@ def check_sampler_shuffle_shard_options(param_dict):
def check_padding_options(param_dict):
"""
Check for valid padded_sample and num_padded of padded samples
Check for valid padded_sample and num_padded of padded samples.
Args:
param_dict (dict): param_dict
param_dict (dict): param_dict.
Returns:
Exception: ValueError or RuntimeError if error
Exception: ValueError or RuntimeError if error.
"""
columns_list = param_dict.get('columns_list')
......@@ -324,11 +338,11 @@ def check_gnn_list_or_ndarray(param, param_name):
Check if the input parameter is list or numpy.ndarray.
Args:
param (list, nd.ndarray): param
param_name (str): param_name
param (list, nd.ndarray): param.
param_name (str): param_name.
Returns:
Exception: TypeError if error
Exception: TypeError if error.
"""
type_check(param, (list, np.ndarray), param_name)
......
......@@ -380,12 +380,7 @@ def check_bucket_batch_by_length(method):
type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)
# check column_names: must be list of string.
if not column_names:
raise ValueError("column_names cannot be empty")
all_string = all(isinstance(item, str) for item in column_names)
if not all_string:
raise TypeError("column_names should be a list of str.")
check_columns(column_names, "column_names")
if element_length_function is None and len(column_names) != 1:
raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
......
......@@ -59,7 +59,7 @@ def test_bucket_batch_invalid_input():
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value)
assert "Argument column_names[0] with value 1 is not of type (<class 'str'>,)." in str(info.value)
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as de
from mindspore import log as logger
import mindspore.dataset.transforms.vision.c_transforms as vision
......@@ -173,7 +174,6 @@ def test_numpy_slices_distributed_sampler():
def test_numpy_slices_sequential_sampler():
logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.")
np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]
......@@ -183,6 +183,33 @@ def test_numpy_slices_sequential_sampler():
assert np.equal(data[0], np_data[i % 8]).all()
def test_numpy_slices_invalid_column_names_type():
    """A non-string entry in column_names must be rejected with a TypeError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected_message = "Argument column_names[0] with value 1 is not of type (<class 'str'>,)."

    with pytest.raises(TypeError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[1], shuffle=False)

    assert expected_message in str(exc_info.value)
def test_numpy_slices_invalid_column_names_string():
    """An empty string inside column_names must be rejected with a ValueError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected_message = "column_names[0] should not be empty"

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[""], shuffle=False)

    assert expected_message in str(exc_info.value)
def test_numpy_slices_invalid_empty_column_names():
    """An empty column_names list must be rejected with a ValueError."""
    logger.info("Test incorrect column_names input")
    samples = [1, 2, 3]
    expected_message = "column_names should not be empty"

    with pytest.raises(ValueError) as exc_info:
        de.NumpySlicesDataset(samples, column_names=[], shuffle=False)

    assert expected_message in str(exc_info.value)
if __name__ == "__main__":
test_numpy_slices_list_1()
test_numpy_slices_list_2()
......@@ -197,3 +224,6 @@ if __name__ == "__main__":
test_numpy_slices_num_samplers()
test_numpy_slices_distributed_sampler()
test_numpy_slices_sequential_sampler()
test_numpy_slices_invalid_column_names_type()
test_numpy_slices_invalid_column_names_string()
test_numpy_slices_invalid_empty_column_names()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册