From 25871e0eeb81238112da3f11e80e55bc908967f0 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 2 Sep 2021 13:12:49 +0800 Subject: [PATCH] add axis check for elementwise op while the dimension of x is equal to the dimension of tensor (#35340) --- paddle/fluid/operators/elementwise/elementwise_op.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index e09f94a6c0..d6cf58f7a1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -85,6 +85,14 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); int max_dim = std::max(x_dims.size(), y_dims.size()); int axis = ctx->Attrs().Get("axis"); + if (x_dims.size() == y_dims.size()) { + PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), true, + platform::errors::InvalidArgument( + "axis should be -1 or 0 while the dimension of " + "tensor X (%s) is equal to the dimension of " + "tensor Y (%s), but received axis: %s", + x_dims.size(), y_dims.size(), axis)); + } PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), true, platform::errors::InvalidArgument( "The axis range must be [%s, %s), but axis is %s. " -- GitLab