diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 5bfd7656d31d56574a73507c9e3e2f8c8bff2fa9..76a5f1b8192617c72992a9e928e6005966536153 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1065,12 +1065,12 @@ def check_split(method): if all_int: all_positive = all(item > 0 for item in sizes) if not all_positive: - raise ValueError("sizes is a list of int, but there should be no negative numbers.") + raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.") if all_float: all_valid_percentages = all(0 < item <= 1 for item in sizes) if not all_valid_percentages: - raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].") + raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].") epsilon = 0.00001 if not abs(sum(sizes) - 1) < epsilon: diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index a51e8524545921c21e5f6ab42485aa785ed63477..7dd140f60e31c298b1e739b5ad9f66b8d2dbcc17 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -38,7 +38,7 @@ def split_with_invalid_inputs(d): with pytest.raises(ValueError) as info: _, _ = d.split([-1, 6]) - assert "there should be no negative numbers" in str(info.value) + assert "there should be no negative or zero numbers" in str(info.value) with pytest.raises(RuntimeError) as info: _, _ = d.split([3, 1]) @@ -54,11 +54,11 @@ def split_with_invalid_inputs(d): with pytest.raises(ValueError) as info: _, _ = d.split([-0.5, 0.5]) - assert "there should be no numbers outside the range [0, 1]" in str(info.value) + assert "there should be no numbers outside the range (0, 1]" in str(info.value) with pytest.raises(ValueError) as info: _, _ = d.split([1.5, 0.5]) - assert "there should be no numbers outside the range [0, 1]" in str(info.value) + assert "there should be no numbers outside the range (0, 1]" in str(info.value) with pytest.raises(ValueError) as info: _, _ = d.split([0.5, 0.6])