From 1755a1549987601af10a3e6228bfbe41b796ff2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:00:05 +0800 Subject: [PATCH] fix div 0 error in conv1_transpose (#50000) --- paddle/phi/kernels/funcs/concat_and_split_functor.cc | 5 +++++ .../unittests/test_functional_conv1d_transpose.py | 12 ++++++++++++ 2 files changed, 17 insertions(+) diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cc b/paddle/phi/kernels/funcs/concat_and_split_functor.cc index aa73ba5f68..fd61484eb8 100644 --- a/paddle/phi/kernels/funcs/concat_and_split_functor.cc +++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cc @@ -37,6 +37,11 @@ struct ConcatFunctor { } int64_t out_rows = rows, out_cols = 0; + PADDLE_ENFORCE_NE( + rows, + 0, + phi::errors::InvalidArgument("The input size should not be 0.")); + std::vector input_cols(input.size()); for (size_t i = 0; i < num; ++i) { int64_t t_cols = input[i].numel() / rows; diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv1d_transpose.py b/python/paddle/fluid/tests/unittests/test_functional_conv1d_transpose.py index 1d4e079f9f..865c848f8b 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv1d_transpose.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv1d_transpose.py @@ -82,5 +82,17 @@ class TestFunctionalConv1DErrorCase2(TestFunctionalConv1DError): self.data_format = "NCL" +class TestFunctionalConv1DErrorCase3(TestFunctionalConv1DError): + def setUp(self): + self.input = np.random.randn(6, 0, 6) + self.filter = np.random.randn(6, 0, 0) + self.bias = None + self.padding = 0 + self.stride = 1 + self.dilation = 1 + self.groups = 1 + self.data_format = "NCL" + + if __name__ == "__main__": unittest.main() -- GitLab