From baf96a123b09a4755a3b4c787efaf256bf1f4cb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 31 Jan 2023 10:42:40 +0800 Subject: [PATCH] fix the div 0 error of pixel_shuffle (#49996) --- paddle/phi/infermeta/unary.cc | 4 ++++ python/paddle/fluid/tests/unittests/test_pixel_shuffle.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 55e895c6622..8cea16f7706 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2533,6 +2533,10 @@ void PixelShuffleInferMeta(const MetaTensor& x, "Input should be a 4-D tensor of format [N, C, H, W] " "or [N, H, W, C], but got %u.", input_dims.size())); + PADDLE_ENFORCE_NE( + upscale_factor, + 0, + phi::errors::InvalidArgument("upscale_factor should not be 0.")); const bool channel_last = (data_format == "NHWC"); diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py index 196a4ddbd40..9600f5a872c 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py @@ -227,6 +227,13 @@ class TestPixelShuffleError(unittest.TestCase): self.assertRaises(TypeError, error_upscale_factor) + def error_0_upscale_factor(): + with paddle.fluid.dygraph.guard(): + x = paddle.uniform([1, 1, 1, 1], dtype='float64') + pixel_shuffle = F.pixel_shuffle(x, 0) + + self.assertRaises(ValueError, error_0_upscale_factor) + def error_data_format(): with paddle.fluid.dygraph.guard(): x = np.random.random([2, 9, 4, 4]).astype("float64") -- GitLab