From 7a0fdeb9f33d680c15ef7a8b746fdf47550bcf57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 7 Feb 2023 20:08:19 +0800 Subject: [PATCH] fix div 0 error in conv1/2/3 (#49999) --- paddle/phi/infermeta/binary.cc | 7 +++++++ .../tests/unittests/test_functional_conv1d.py | 12 ++++++++++++ .../tests/unittests/test_functional_conv2d.py | 14 ++++++++++++++ .../tests/unittests/test_functional_conv3d.py | 14 ++++++++++++++ 4 files changed, 47 insertions(+) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 3ca56e0602..7551c51d64 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -467,6 +467,13 @@ void ConvInferMeta(const MetaTensor& input, const bool channel_last = (config.is_run_mkldnn_kernel == false) && (data_format == "NHWC" || data_format == "NDHWC"); + for (int i = 0; i < 2; ++i) { + PADDLE_ENFORCE_NE(in_dims[i], + 0, + phi::errors::InvalidArgument( + "The size of Op(Conv) inputs should not be 0.")); + } + PADDLE_ENFORCE_EQ( in_dims.size() == 4 || in_dims.size() == 5, true, diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv1d.py b/python/paddle/fluid/tests/unittests/test_functional_conv1d.py index 0bd7fa1878..d050c61639 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv1d.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv1d.py @@ -70,5 +70,17 @@ class TestFunctionalConv1DErrorCase1(TestFunctionalConv1DError): self.data_format = "NCL" +class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError): + def setUp(self): + self.input = np.random.randn(0, 0, 0) + self.filter = np.random.randn(1, 0, 0) + self.bias = None + self.padding = 0 + self.stride = 1 + self.dilation = 1 + self.groups = 1 + self.data_format = "NCL" + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py index 00cc6c07aa..c78f6c35b0 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv2d.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d.py @@ -569,6 +569,20 @@ class TestFunctionalConv2DErrorCase13(TestFunctionalConv2DErrorCase12): self.data_format = "NCHW" +class TestFunctionalConv2DErrorCase14(TestFunctionalConv2DErrorCase12): + def setUp(self): + self.input = np.random.randn(0, 0, 0, 0) + self.filter = np.random.randn(1, 0, 0, 0) + self.num_filters = 0 + self.filter_size = 0 + self.bias = None + self.padding = 0 + self.stride = 1 + self.dilation = 1 + self.groups = 1 + self.data_format = "NCHW" + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv3d.py b/python/paddle/fluid/tests/unittests/test_functional_conv3d.py index 62322f8e3d..5e867036dd 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv3d.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv3d.py @@ -544,6 +544,20 @@ class TestFunctionalConv3DErrorCase12(TestFunctionalConv3DErrorCase11): self.data_format = "NCDHW" +class TestFunctionalConv3DErrorCase13(TestFunctionalConv3DErrorCase11): + def setUp(self): + self.input = np.random.randn(0, 0, 0, 0, 0) + self.filter = np.random.randn(1, 0, 0, 0, 0) + self.num_filters = 1 + self.filter_size = 1 + self.bias = None + self.padding = 0 + self.stride = 1 + self.dilation = 1 + self.groups = 1 + self.data_format = "NCDHW" + + if __name__ == "__main__": paddle.enable_static() unittest.main() -- GitLab