提交 29aa5899 编写于 作者: P peilin-wang

added check for invalid type for boolean args

上级 709dfd7e
...@@ -606,8 +606,15 @@ def check_bucket_batch_by_length(method): ...@@ -606,8 +606,15 @@ def check_bucket_batch_by_length(method):
nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
check_param_type(nreq_param_list, param_dict, list) check_param_type(nreq_param_list, param_dict, list)
nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
check_param_type(nbool_param_list, param_dict, bool)
# check column_names: must be list of string. # check column_names: must be list of string.
column_names = param_dict.get("column_names") column_names = param_dict.get("column_names")
if not column_names:
raise ValueError("column_names cannot be empty")
all_string = all(isinstance(item, str) for item in column_names) all_string = all(isinstance(item, str) for item in column_names)
if not all_string: if not all_string:
raise TypeError("column_names should be a list of str.") raise TypeError("column_names should be a list of str.")
......
...@@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input(): ...@@ -53,6 +53,9 @@ def test_bucket_batch_invalid_input():
negative_bucket_batch_sizes = [1, 2, 3, -4] negative_bucket_batch_sizes = [1, 2, 3, -4]
zero_bucket_batch_sizes = [0, 1, 2, 3] zero_bucket_batch_sizes = [0, 1, 2, 3]
invalid_type_pad_to_bucket_boundary = ""
invalid_type_drop_remainder = ""
with pytest.raises(TypeError) as info: with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
assert "column_names should be a list of str" in str(info.value) assert "column_names should be a list of str" in str(info.value)
...@@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input(): ...@@ -93,6 +96,16 @@ def test_bucket_batch_invalid_input():
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries) _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value) assert "bucket_batch_sizes must contain one element more than bucket_boundaries" in str(info.value)
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, invalid_type_pad_to_bucket_boundary)
assert "Wrong input type for pad_to_bucket_boundary, should be <class 'bool'>" in str(info.value)
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes,
None, None, False, invalid_type_drop_remainder)
assert "Wrong input type for drop_remainder, should be <class 'bool'>" in str(info.value)
def test_bucket_batch_multi_bucket_no_padding(): def test_bucket_batch_multi_bucket_no_padding():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"]) dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册