From 68e2097897c0b817707eee29c248f66b00ff7205 Mon Sep 17 00:00:00 2001 From: liyong Date: Sun, 28 Jun 2020 10:20:04 +0800 Subject: [PATCH] fix split erroer message --- mindspore/dataset/engine/validators.py | 4 ++-- tests/ut/python/dataset/test_split.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 5bfd7656d..76a5f1b81 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1065,12 +1065,12 @@ def check_split(method): if all_int: all_positive = all(item > 0 for item in sizes) if not all_positive: - raise ValueError("sizes is a list of int, but there should be no negative numbers.") + raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.") if all_float: all_valid_percentages = all(0 < item <= 1 for item in sizes) if not all_valid_percentages: - raise ValueError("sizes is a list of float, but there should be no numbers outside the range [0, 1].") + raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].") epsilon = 0.00001 if not abs(sum(sizes) - 1) < epsilon: diff --git a/tests/ut/python/dataset/test_split.py b/tests/ut/python/dataset/test_split.py index a51e85245..7dd140f60 100644 --- a/tests/ut/python/dataset/test_split.py +++ b/tests/ut/python/dataset/test_split.py @@ -38,7 +38,7 @@ def split_with_invalid_inputs(d): with pytest.raises(ValueError) as info: _, _ = d.split([-1, 6]) - assert "there should be no negative numbers" in str(info.value) + assert "there should be no negative or zero numbers" in str(info.value) with pytest.raises(RuntimeError) as info: _, _ = d.split([3, 1]) @@ -54,11 +54,11 @@ def split_with_invalid_inputs(d): with pytest.raises(ValueError) as info: _, _ = d.split([-0.5, 0.5]) - assert "there should be no numbers outside the range [0, 1]" in str(info.value) + assert "there should be no numbers outside the range (0, 1]" in str(info.value) with pytest.raises(ValueError) as info: _, _ = d.split([1.5, 0.5]) - assert "there should be no numbers outside the range [0, 1]" in str(info.value) + assert "there should be no numbers outside the range (0, 1]" in str(info.value) with pytest.raises(ValueError) as info: _, _ = d.split([0.5, 0.6]) -- GitLab