diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index cec3368a51867c3497b38a5271e9856861544232..ccf5aa6a62268be87d97661989aec31f822e2218 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -164,6 +164,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); + if (ctx->HasInput("ReserveSpace")) { + ctx->SetOutputDim("ReserveSpace", {-1}); + } } framework::OpKernelType BatchNormOp::GetExpectedKernelType( diff --git a/paddle/fluid/operators/common_infer_shape_functions.cc b/paddle/fluid/operators/common_infer_shape_functions.cc index 9dce94d16b4db917f46931d57c9ccf768a3ffee8..b256d94a5a894f6cf126c7821da356b5341ab839 100644 --- a/paddle/fluid/operators/common_infer_shape_functions.cc +++ b/paddle/fluid/operators/common_infer_shape_functions.cc @@ -40,12 +40,13 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - platform::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + platform::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); if (x_dims.size() > y_dims.size()) { std::fill(y_dims_array, y_dims_array + axis, 1); if (axis + y_dims.size() < max_dim) { diff --git a/paddle/fluid/operators/elementwise/elementwise_npu.h b/paddle/fluid/operators/elementwise/elementwise_npu.h index 45e5a548f91ede3c5800d4544b42c6756d7aeb26..b7e85c45f4c7ccba142f91be763d129686d40502 100644 --- a/paddle/fluid/operators/elementwise/elementwise_npu.h +++ b/paddle/fluid/operators/elementwise/elementwise_npu.h @@ -123,12 +123,13 @@ void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx, platform::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - platform::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + platform::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); for (int i = 0; i < x_dims.size(); ++i) { dst_dims_vec[i + x_axis] = diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index c90c3a54de06f67d9bbf7a424fae5ea6db99ce34..b9e84c2df57141108fd619162cf87ad71e89ef59 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -640,6 +640,9 @@ void BatchNormInferMeta(const MetaTensor& x, if (saved_variance) { saved_variance->set_dims({C}); } + if (reserve_space) { + reserve_space->set_dims({-1}); + } y->share_lod(x); y->set_dtype(x.dtype()); } diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index 01b06120965fc910c3856d9241d753c8bb1cd87a..f7524320e8807fec9a6567c020638ce5a64141cd 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -45,12 +45,13 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + phi::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); if (x_dims.size() > y_dims.size()) { std::fill(y_dims_array, y_dims_array + axis, 1); if (axis + y_dims.size() < max_dim) { diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 29da61741385302ad943c25c6bf12ec77327c9b7..a1b0a70956931d41c7ce14a33d8c026003e69ed6 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -326,12 +326,13 @@ void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - phi::errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + phi::errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); std::vector x_dims_array(max_dim); std::vector y_dims_array(max_dim); std::vector out_dims_array(max_dim); @@ -394,12 +395,13 @@ void ElementwiseCompute(const CPUContext &dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index f8007319d697ce12d6087993e41806760f1a5606..c55ce6a89ae1c46bc36e78d3d8b076baee3d599d 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -287,12 +287,13 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext &ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { @@ -1725,12 +1726,13 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { diff --git a/paddle/phi/kernels/xpu/elementwise.h b/paddle/phi/kernels/xpu/elementwise.h index dfaaae59bb3cef6a853c819179eb872092c2ecfc..46bac6ce299147550e56f1e025feceb2534095fc 100644 --- a/paddle/phi/kernels/xpu/elementwise.h +++ b/paddle/phi/kernels/xpu/elementwise.h @@ -51,12 +51,13 @@ void XPUElementwise(const XPUContext& dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); std::vector x_dims_vec(max_dim, 1); std::vector y_dims_vec(max_dim, 1); if (x_dims.size() == max_dim) { @@ -121,12 +122,13 @@ void XPUElementwiseGrad(const XPUContext& dev_ctx, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); - PADDLE_ENFORCE_LE(axis, - max_dim, - errors::InvalidArgument( - "Axis should be less than %d, but received axis is %d.", - max_dim, - axis)); + PADDLE_ENFORCE_LE( + axis, + max_dim, + errors::InvalidArgument( + "Axis should be less than or equal to %d, but received axis is %d.", + max_dim, + axis)); std::vector x_dims_vec(max_dim, 1); std::vector y_dims_vec(max_dim, 1); if (x_dims.size() == max_dim) {