From 82edc65ba2e533c25cf6cd34117f43268043ba44 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Tue, 31 Jan 2023 10:50:16 +0800 Subject: [PATCH] =?UTF-8?q?Fix=20=E7=A9=BA=E6=8C=87=E9=92=88=20(Null=20poi?= =?UTF-8?q?nter)=20of=20case=2014=20paddle.atan2=20(#49973)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add elements count check in atan2 * add unittest and pre-check in inferMeta * add dimension check --- paddle/phi/infermeta/binary.cc | 20 +++++++++++++++++++ paddle/phi/kernels/impl/atan2_kernel_impl.h | 8 ++++++++ .../fluid/tests/unittests/test_atan2_op.py | 12 +++++++++++ 3 files changed, 40 insertions(+) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 561938adca..3ca56e0602 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -142,6 +142,26 @@ void KLDivInferMeta(const MetaTensor& x, } void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + PADDLE_ENFORCE_EQ( + x_dims.size(), + y_dims.size(), + phi::errors::InvalidArgument("The rank (%d) of X shall be same as " + "rank (%d) of Y.", + x_dims.size(), + y_dims.size())); + + if (x_dims.size() > 0) + PADDLE_ENFORCE_LE(x_dims[0], + y_dims[0], + phi::errors::InvalidArgument( + "The count (%d) of elements of X shall not " + "greater than count (%d) of elements of Y.", + x_dims[0], + y_dims[0])); + out->share_meta(x); if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 || y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) { diff --git a/paddle/phi/kernels/impl/atan2_kernel_impl.h b/paddle/phi/kernels/impl/atan2_kernel_impl.h index 2cae914e2f..b7799a7770 100644 --- a/paddle/phi/kernels/impl/atan2_kernel_impl.h +++ b/paddle/phi/kernels/impl/atan2_kernel_impl.h @@ -77,6 +77,14 @@ void Atan2Kernel(const Context& ctx, auto x_data = x.data(); auto y_data = y.data(); + PADDLE_ENFORCE_LE( + numel, + y.numel(), + phi::errors::InvalidArgument("The count (%d) of elements of X shall not " + "greater than count (%d) of elements of Y.", + numel, + y.numel())); + auto* out_data = ctx.template Alloc::type>( out, size_t(x.numel() * sizeof(typename Atan2Out::type))); diff --git a/python/paddle/fluid/tests/unittests/test_atan2_op.py b/python/paddle/fluid/tests/unittests/test_atan2_op.py index 77ad77e325..6b62b25ac5 100644 --- a/python/paddle/fluid/tests/unittests/test_atan2_op.py +++ b/python/paddle/fluid/tests/unittests/test_atan2_op.py @@ -130,6 +130,18 @@ class TestAtan2API(unittest.TestCase): run(place) +class TestAtan2Error(unittest.TestCase): + def test_mismatch(self): + paddle.enable_static() + + def test_mismatch_numel(): + X = paddle.fluid.data('X', (1,), dtype=np.float64) + Y = paddle.fluid.data('Y', (0,), dtype=np.float64) + out = paddle.atan2(X, Y) + + self.assertRaises(ValueError, test_mismatch_numel) + + if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab