From 185494171cd01e9499fe556d11f325b3c2c3872c Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Fri, 11 Nov 2022 19:54:09 +0800 Subject: [PATCH] [Zero-Dim] fix batch_norm op infermeta bug (#47858) --- paddle/fluid/operators/batch_norm_op.cc | 3 +++ .../operators/common_infer_shape_functions.cc | 13 +++++----- .../operators/elementwise/elementwise_npu.h | 13 +++++----- paddle/phi/infermeta/multiary.cc | 3 +++ paddle/phi/kernels/funcs/common_shape.h | 13 +++++----- paddle/phi/kernels/funcs/elementwise_base.h | 26 ++++++++++--------- .../phi/kernels/funcs/elementwise_grad_base.h | 26 ++++++++++--------- paddle/phi/kernels/xpu/elementwise.h | 26 ++++++++++--------- 8 files changed, 69 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index cec3368a51..ccf5aa6a62 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); + if (ctx->HasInput("ReserveSpace")) { + ctx->SetOutputDim("ReserveSpace", {-1}); + } } framework::OpKernelType BatchNormOp::GetExpectedKernelType( diff --git a/paddle/fluid/operators/common_infer_shape_functions.cc b/paddle/fluid/operators/common_infer_shape_functions.cc index 9dce94d16b..b256d94a5a 100644 --- a/paddle/fluid/operators/common_infer_shape_functions.cc +++ b/paddle/fluid/operators/common_infer_shape_functions.cc @@ -40,12 +40,13 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - platform::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + platform::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); if (x_dims.size() > y_dims.size()) { std::fill(y_dims_array, y_dims_array + axis, 1); if (axis + y_dims.size() < max_dim) { diff --git a/paddle/fluid/operators/elementwise/elementwise_npu.h b/paddle/fluid/operators/elementwise/elementwise_npu.h index 45e5a548f9..b7e85c45f4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_npu.h +++ b/paddle/fluid/operators/elementwise/elementwise_npu.h @@ -123,12 +123,13 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx, platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - platform::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + platform::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); for (int i = 0; i < x_dims.size(); ++i) { dst_dims_vec[i + x_axis] = diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index c90c3a54de..b9e84c2df5 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x, if (saved_variance) { saved_variance->set_dims({C}); } + if (reserve_space) { + reserve_space->set_dims({-1}); + } y->share_lod(x); y->set_dtype(x.dtype()); } diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index 01b0612096..f7524320e8 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -45,12 +45,13 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + phi::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); if (x_dims.size() > y_dims.size()) { std::fill(y_dims_array, y_dims_array + axis, 1); if (axis + y_dims.size() < max_dim) { diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 29da617413..a1b0a70956 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -326,12 +326,13 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + phi::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); std::vector x_dims_array(max_dim); std::vector y_dims_array(max_dim); std::vector out_dims_array(max_dim); @@ -394,12 +395,13 @@ void ElementwiseCompute(const CPUContext &dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index f8007319d6..c55ce6a89a 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -287,12 +287,13 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { @@ -1725,12 +1726,13 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { diff --git a/paddle/phi/kernels/xpu/elementwise.h b/paddle/phi/kernels/xpu/elementwise.h index dfaaae59bb..46bac6ce29 100644 --- a/paddle/phi/kernels/xpu/elementwise.h +++ b/paddle/phi/kernels/xpu/elementwise.h @@ -51,12 +51,13 @@ void XPUElementwise(const XPUContext& dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); std::vector x_dims_vec(max_dim, 1); std::vector y_dims_vec(max_dim, 1); if (x_dims.size() == max_dim) { @@ -121,12 +122,13 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); std::vector x_dims_vec(max_dim, 1); std::vector y_dims_vec(max_dim, 1); if (x_dims.size() == max_dim) { -- GitLab