diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 561938adca80a22cc3700baab3dc58c8bf9a6321..3ca56e0602c1d2413ef5ceef36ecbffe047ecf8a 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -142,6 +142,26 @@ void KLDivInferMeta(const MetaTensor& x, } void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + PADDLE_ENFORCE_EQ( + x_dims.size(), + y_dims.size(), + phi::errors::InvalidArgument("The rank (%d) of X shall be same as " + "rank (%d) of Y.", + x_dims.size(), + y_dims.size())); + + if (x_dims.size() > 0) + PADDLE_ENFORCE_LE(x_dims[0], + y_dims[0], + phi::errors::InvalidArgument( + "The count (%d) of elements of X shall not " + "greater than count (%d) of elements of Y.", + x_dims[0], + y_dims[0])); + out->share_meta(x); if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 || y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) { diff --git a/paddle/phi/kernels/impl/atan2_kernel_impl.h b/paddle/phi/kernels/impl/atan2_kernel_impl.h index 2cae914e2f61555377f7a41b3d89cdbb2b589247..b7799a777046f4d63034df1aff23c68915d8d7a8 100644 --- a/paddle/phi/kernels/impl/atan2_kernel_impl.h +++ b/paddle/phi/kernels/impl/atan2_kernel_impl.h @@ -77,6 +77,14 @@ void Atan2Kernel(const Context& ctx, auto x_data = x.data(); auto y_data = y.data(); + PADDLE_ENFORCE_LE( + numel, + y.numel(), + phi::errors::InvalidArgument("The count (%d) of elements of X shall not " + "greater than count (%d) of elements of Y.", + numel, + y.numel())); + auto* out_data = ctx.template Alloc::type>( out, size_t(x.numel() * sizeof(typename Atan2Out::type))); diff --git a/python/paddle/fluid/tests/unittests/test_atan2_op.py b/python/paddle/fluid/tests/unittests/test_atan2_op.py index 77ad77e3252b88dc5beb3a53e7f586e6ff3153ba..6b62b25ac5d8ae3263da5827b3be02170330fa55 100644 --- a/python/paddle/fluid/tests/unittests/test_atan2_op.py +++ b/python/paddle/fluid/tests/unittests/test_atan2_op.py @@ -130,6 +130,18 @@ class TestAtan2API(unittest.TestCase): run(place) +class TestAtan2Error(unittest.TestCase): + def test_mismatch(self): + paddle.enable_static() + + def test_mismatch_numel(): + X = paddle.fluid.data('X', (1,), dtype=np.float64) + Y = paddle.fluid.data('Y', (0,), dtype=np.float64) + out = paddle.atan2(X, Y) + + self.assertRaises(ValueError, test_mismatch_numel) + + if __name__ == '__main__': paddle.enable_static() unittest.main()