From cd933c0aa215e2848d066e6f8ddf12b3e436d2e5 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Mon, 20 Jul 2020 19:33:00 +0800 Subject: [PATCH] refine error message of randint (#25613) --- python/paddle/fluid/tests/unittests/test_randint_op.py | 1 + python/paddle/tensor/random.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_randint_op.py b/python/paddle/fluid/tests/unittests/test_randint_op.py index 89739a37fd9..5b2d5be346a 100644 --- a/python/paddle/fluid/tests/unittests/test_randint_op.py +++ b/python/paddle/fluid/tests/unittests/test_randint_op.py @@ -57,6 +57,7 @@ class TestRandintOpError(unittest.TestCase): self.assertRaises(TypeError, paddle.randint, 5, shape=np.array([2])) self.assertRaises(TypeError, paddle.randint, 5, dtype='float32') self.assertRaises(ValueError, paddle.randint, 5, 5) + self.assertRaises(ValueError, paddle.randint, -5) class TestRandintOp_attr_tensorlist(OpTest): diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 8ef9dde0880..eac99163e05 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -114,6 +114,10 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): """ if high is None: + if low <= 0: + raise ValueError( + "If high is None, low must be greater than 0, but received low = {0}.". + format(low)) high = low low = 0 if dtype is None: -- GitLab