未验证 提交 3ab6faa8 编写于 作者: R RedContritio 提交者: GitHub

Fix div 0 error of case11: paddle.nn.functional.max_pool1d/max_pool2d/max_pool3d (#50010)

* add stride check for MaxPool

* add unittests
上级 e4e94a88
......@@ -27,6 +27,11 @@ inline int MaxPoolOutputSize(int input_size,
int filter_size,
int padding,
int stride) {
PADDLE_ENFORCE_NE(
stride,
0,
phi::errors::InvalidArgument(
"The stride of MaxPool shall not be 0, but received %d.", stride));
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
......
......@@ -402,6 +402,11 @@ inline int MaxPoolOutputSize(int input_size,
int filter_size,
int padding,
int stride) {
PADDLE_ENFORCE_NE(
stride,
0,
phi::errors::InvalidArgument(
"The stride of MaxPool shall not be 0, but received %d.", stride));
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
......
......@@ -274,7 +274,7 @@ class TestPool1D_API(unittest.TestCase):
self.check_max_dygraph_return_index_results(place)
class TestPool2DError_API(unittest.TestCase):
class TestPool1DError_API(unittest.TestCase):
def test_error_api(self):
def run1():
with fluid.dygraph.guard():
......@@ -417,6 +417,18 @@ class TestPool2DError_API(unittest.TestCase):
self.assertRaises(ValueError, run_stride_out_of_range)
def run_zero_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1]), dtype='float32'
)
out = F.max_pool1d(
x, 1, stride=0, padding=1, return_mask=True, ceil_mode=True
)
self.assertRaises(ValueError, run_zero_stride)
if __name__ == '__main__':
unittest.main()
......@@ -597,6 +597,18 @@ class TestPool2DError_API(unittest.TestCase):
self.assertRaises(ValueError, run_stride_out_of_range)
def run_zero_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1, 1]), dtype='float32'
)
out = max_pool2d(
x, 1, stride=0, padding=1, return_mask=True, ceil_mode=True
)
self.assertRaises(ValueError, run_zero_stride)
if __name__ == '__main__':
unittest.main()
......@@ -563,6 +563,18 @@ class TestPool3DError_API(unittest.TestCase):
self.assertRaises(ValueError, run_size_out_of_range)
def run_zero_stride():
with fluid.dygraph.guard():
array = np.array([1], dtype=np.float32)
x = paddle.to_tensor(
np.reshape(array, [1, 1, 1, 1, 1]), dtype='float32'
)
out = max_pool3d(
x, 1, stride=0, padding=1, return_mask=True, ceil_mode=True
)
self.assertRaises(ValueError, run_zero_stride)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册