diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index e09f94a6c0fee08290d465c9a46d0334389b2a92..d6cf58f7a157f31de6440291213e3293289d8651 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -85,6 +85,14 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); int max_dim = std::max(x_dims.size(), y_dims.size()); int axis = ctx->Attrs().Get("axis"); + if (x_dims.size() == y_dims.size()) { + PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), true, + platform::errors::InvalidArgument( + "axis should be -1 or 0 while the dimension of " + "tensor X (%s) is equal to the dimension of " + "tensor Y (%s), but received axis: %s", + x_dims.size(), y_dims.size(), axis)); + } PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), true, platform::errors::InvalidArgument( "The axis range must be [%s, %s), but axis is %s. "