未验证 提交 baf96a12 编写于 作者: 张春乔 提交者: GitHub

fix the div 0 error of pixel_shuffle (#49996)

上级 4976153d
...@@ -2533,6 +2533,10 @@ void PixelShuffleInferMeta(const MetaTensor& x, ...@@ -2533,6 +2533,10 @@ void PixelShuffleInferMeta(const MetaTensor& x,
"Input should be a 4-D tensor of format [N, C, H, W] " "Input should be a 4-D tensor of format [N, C, H, W] "
"or [N, H, W, C], but got %u.", "or [N, H, W, C], but got %u.",
input_dims.size())); input_dims.size()));
PADDLE_ENFORCE_NE(
upscale_factor,
0,
phi::errors::InvalidArgument("upscale_factor should not be 0."));
const bool channel_last = (data_format == "NHWC"); const bool channel_last = (data_format == "NHWC");
......
...@@ -227,6 +227,13 @@ class TestPixelShuffleError(unittest.TestCase): ...@@ -227,6 +227,13 @@ class TestPixelShuffleError(unittest.TestCase):
self.assertRaises(TypeError, error_upscale_factor) self.assertRaises(TypeError, error_upscale_factor)
def error_0_upscale_factor():
with paddle.fluid.dygraph.guard():
x = paddle.uniform([1, 1, 1, 1], dtype='float64')
pixel_shuffle = F.pixel_shuffle(x, 0)
self.assertRaises(ValueError, error_0_upscale_factor)
def error_data_format(): def error_data_format():
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
x = np.random.random([2, 9, 4, 4]).astype("float64") x = np.random.random([2, 9, 4, 4]).astype("float64")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册