diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 9857608c19ba8a10ea116a6a794adb9b61fc6329..2a0bef3b422f1fc7be57755f7e5993fd7a3af4bf 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -643,9 +643,9 @@ def check_bucket_batch_by_length(method): if not all_int: raise TypeError("bucket_batch_sizes should be a list of int.") - all_non_negative = all(item >= 0 for item in bucket_batch_sizes) + all_non_negative = all(item > 0 for item in bucket_batch_sizes) if not all_non_negative: - raise ValueError("bucket_batch_sizes cannot contain any negative numbers.") + raise ValueError("bucket_batch_sizes should be a list of positive numbers.") if param_dict.get('pad_info') is not None: check_type(param_dict["pad_info"], "pad_info", dict) diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index bca30723e984aef6e674d4087c50ac16f4bf724a..4436f98e53462e7d423dc48958d98721b9b23e98 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -51,6 +51,7 @@ def test_bucket_batch_invalid_input(): bucket_batch_sizes = [1, 1, 1, 1] invalid_bucket_batch_sizes = ["1", "2", "3", "4"] negative_bucket_batch_sizes = [1, 2, 3, -4] + zero_bucket_batch_sizes = [0, 1, 2, 3] with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) @@ -82,7 +83,11 @@ def test_bucket_batch_invalid_input(): with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, negative_bucket_batch_sizes) - assert "bucket_batch_sizes cannot contain any negative numbers" in str(info.value) + assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value) + + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, zero_bucket_batch_sizes) + assert "bucket_batch_sizes should be a list of positive numbers" in str(info.value) with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_boundaries)