提交 d08a89ab 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2737 makes 0 an invaild bucket size

Merge pull request !2737 from Peilin/zero-input-bucket-check
......@@ -643,9 +643,9 @@ def check_bucket_batch_by_length(method):
if not all_int:
raise TypeError("bucket_batch_sizes should be a list of int.")
all_non_negative = all(item >= 0 for item in bucket_batch_sizes)
all_non_negative = all(item > 0 for item in bucket_batch_sizes)
if not all_non_negative:
raise ValueError("bucket_batch_sizes cannot contain any negative numbers.")
raise ValueError("bucket_batch_sizes should be a list of positive numbers.")
if param_dict.get('pad_info') is not None:
check_type(param_dict["pad_info"], "pad_info", dict)
......
......@@ -51,6 +51,7 @@ def test_bucket_batch_invalid_input():
bucket_batch_sizes = [1, 1, 1, 1]
invalid_bucket_batch_sizes = ["1", "2", "3", "4"]
negative_bucket_batch_sizes = [1, 2, 3, -4]
zero_bucket_batch_sizes = [0, 1, 2, 3]
with pytest.raises(TypeError) as info:
_ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes)
......@@ -82,7 +83,11 @@ def test_bucket_batch_invalid_input():
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes)
assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value)
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value)
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, zero_bucket_batch_sizes)
assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value)
with pytest.raises(ValueError) as info:
_ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册