未验证 提交 82edc65b 编写于 作者: R RedContritio 提交者: GitHub

Fix 空指针 (Null pointer) of case 14 paddle.atan2 (#49973)

* add elements count check in atan2

* add unittest and pre-check in inferMeta

* add dimension check
上级 7bb67db3
...@@ -142,6 +142,26 @@ void KLDivInferMeta(const MetaTensor& x, ...@@ -142,6 +142,26 @@ void KLDivInferMeta(const MetaTensor& x,
} }
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
y_dims.size(),
phi::errors::InvalidArgument("The rank (%d) of X shall be same as "
"rank (%d) of Y.",
x_dims.size(),
y_dims.size()));
if (x_dims.size() > 0)
PADDLE_ENFORCE_LE(x_dims[0],
y_dims[0],
phi::errors::InvalidArgument(
"The count (%d) of elements of X shall not "
"greater than count (%d) of elements of Y.",
x_dims[0],
y_dims[0]));
out->share_meta(x); out->share_meta(x);
if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 || if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 ||
y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) { y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) {
......
...@@ -77,6 +77,14 @@ void Atan2Kernel(const Context& ctx, ...@@ -77,6 +77,14 @@ void Atan2Kernel(const Context& ctx,
auto x_data = x.data<T>(); auto x_data = x.data<T>();
auto y_data = y.data<T>(); auto y_data = y.data<T>();
PADDLE_ENFORCE_LE(
numel,
y.numel(),
phi::errors::InvalidArgument("The count (%d) of elements of X shall not "
"greater than count (%d) of elements of Y.",
numel,
y.numel()));
auto* out_data = ctx.template Alloc<typename Atan2Out<T>::type>( auto* out_data = ctx.template Alloc<typename Atan2Out<T>::type>(
out, size_t(x.numel() * sizeof(typename Atan2Out<T>::type))); out, size_t(x.numel() * sizeof(typename Atan2Out<T>::type)));
......
...@@ -130,6 +130,18 @@ class TestAtan2API(unittest.TestCase): ...@@ -130,6 +130,18 @@ class TestAtan2API(unittest.TestCase):
run(place) run(place)
class TestAtan2Error(unittest.TestCase):
def test_mismatch(self):
paddle.enable_static()
def test_mismatch_numel():
X = paddle.fluid.data('X', (1,), dtype=np.float64)
Y = paddle.fluid.data('Y', (0,), dtype=np.float64)
out = paddle.atan2(X, Y)
self.assertRaises(ValueError, test_mismatch_numel)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册