From a48ef36002c9f52bb7b4b6f6c3426cc913433ce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:04:07 +0800 Subject: [PATCH] fix the NullPointerError of median (#50017) --- python/paddle/fluid/tests/unittests/test_median.py | 1 + python/paddle/tensor/stat.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_median.py b/python/paddle/fluid/tests/unittests/test_median.py index a62e722dd0..1f90faeac0 100644 --- a/python/paddle/fluid/tests/unittests/test_median.py +++ b/python/paddle/fluid/tests/unittests/test_median.py @@ -86,6 +86,7 @@ class TestMedian(unittest.TestCase): x = paddle.arange(12).reshape([3, 4]) self.assertRaises(ValueError, paddle.median, x, 1.0) self.assertRaises(ValueError, paddle.median, x, 2) + self.assertRaises(ValueError, paddle.median, paddle.to_tensor([])) if __name__ == '__main__': diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index e23f28aa76..cc94aee415 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -406,6 +406,9 @@ def median(x, axis=None, keepdim=False, name=None): if not isinstance(x, Variable): raise TypeError("In median, the input x should be a Tensor.") + if x.size == 0: + raise ValueError("In median, the size of input x should not be 0.") + if len(x.shape) == 0: return x.clone() -- GitLab