未验证 提交 1451fa51 编写于 作者: R RedContritio 提交者: GitHub

Fix div 0 error of case10: paddle.nn.functional.max_pool2d/max_pool3d (#50012)

* add stride check for PoolOutputSize

* add unittest
上级 c8548af3
...@@ -371,6 +371,13 @@ inline int PoolOutputSize(int input_size, ...@@ -371,6 +371,13 @@ inline int PoolOutputSize(int input_size,
int padding_2, int padding_2,
int stride, int stride,
bool ceil_mode) { bool ceil_mode) {
PADDLE_ENFORCE_NE(
stride,
0,
phi::errors::InvalidArgument(
"The stride of PoolOutputSize shall not be 0, but received %d.",
stride));
int output_size; int output_size;
if (!ceil_mode) { if (!ceil_mode) {
output_size = output_size =
......
...@@ -429,6 +429,16 @@ class TestPool1DError_API(unittest.TestCase): ...@@ -429,6 +429,16 @@ class TestPool1DError_API(unittest.TestCase):
self.assertRaises(ValueError, run_zero_stride) self.assertRaises(ValueError, run_zero_stride)
def run_zero_tuple_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1]), dtype='float32'
)
out = F.max_pool1d(x, 1, stride=(0))
self.assertRaises(ValueError, run_zero_tuple_stride)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -609,6 +609,18 @@ class TestPool2DError_API(unittest.TestCase): ...@@ -609,6 +609,18 @@ class TestPool2DError_API(unittest.TestCase):
self.assertRaises(ValueError, run_zero_stride) self.assertRaises(ValueError, run_zero_stride)
def run_zero_tuple_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1, 1]), dtype='float32'
)
out = max_pool2d(
x, 1, stride=(0, 0), return_mask=False, data_format='NHWC'
)
self.assertRaises(ValueError, run_zero_tuple_stride)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -575,6 +575,16 @@ class TestPool3DError_API(unittest.TestCase): ...@@ -575,6 +575,16 @@ class TestPool3DError_API(unittest.TestCase):
self.assertRaises(ValueError, run_zero_stride) self.assertRaises(ValueError, run_zero_stride)
def run_zero_tuple_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1, 1, 1]), dtype='float32'
)
out = max_pool3d(x, 1, stride=(0, 0, 0), ceil_mode=False)
self.assertRaises(ValueError, run_zero_tuple_stride)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册