diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc
deleted file mode 100644
index bddd9e80544c44a057ef2bb24575b14a0352d5c9..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/batch_norm_op_xpu.cc
+++ /dev/null
@@ -1,430 +0,0 @@
-/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-
-#include <memory>
-#include <string>
-
-#include "paddle/fluid/operators/batch_norm_op.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using DDim = framework::DDim;
-
-template <typename DeviceContext, typename T>
-class BatchNormXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto epsilon = ctx.Attr<float>("epsilon");
-    float momentum = ctx.Attr<float>("momentum");
-    const auto is_test = ctx.Attr<bool>("is_test");
-    const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
-    const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
-    bool test_mode = is_test && (!trainable_stats);
-
-    bool global_stats = test_mode || use_global_stats;
-    const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
-    const auto data_layout = framework::StringToDataLayout(data_layout_str);
-    PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
-                      true,
-                      platform::errors::InvalidArgument(
-                          "The 'data_layout' attribute must be NCHW or NHWC. "
-                          "But recevived 'data_layout' is [%s].",
-                          data_layout_str));
-
-    const auto *x = ctx.Input<Tensor>("X");
-    const auto &x_dims = x->dims();
-    PADDLE_ENFORCE_EQ(
-        x_dims.size() >= 2 && x_dims.size() <= 5,
-        true,
-        platform::errors::InvalidArgument(
-            "The size of input's dimensions should be between 2 and 5"
-            "But received: the size of input's dimensions is [%d]",
-            x_dims.size()));
-
-    int N = -1, C = -1, H = -1, W = -1, D = -1;
-    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
-    N = (N == 0) ? 1 : N;
-    C = (C == 0) ? 1 : C;
-    H = (H == 0) ? 1 : H;
-    W = (W == 0) ? 1 : W;
-
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *bias = ctx.Input<Tensor>("Bias");
-    const auto *x_data = x->data<T>();
-    const auto *scale_data = scale->data<T>();
-    const auto *bias_data = bias->data<T>();
-
-    auto *y = ctx.Output<Tensor>("Y");
-    auto *mean_out = ctx.Output<Tensor>("MeanOut");
-    auto *variance_out = ctx.Output<Tensor>("VarianceOut");
-    auto *saved_mean = ctx.Output<Tensor>("SavedMean");
-    auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
-
-    // alloc memory
-    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
-
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-    bool is_nchw = data_layout_str == "NCHW";
-
-    if (!global_stats) {
-      auto *mean_out_data = mean_out->data<T>();
-      auto *variance_out_data = variance_out->data<T>();
-      auto *saved_mean_data = saved_mean->data<T>();
-      auto *saved_variance_data = saved_variance->data<T>();
-
-      // if MomentumTensor is set, use MomentumTensor value, momentum
-      // is only used in this training branch
-      if (ctx.HasInput("MomentumTensor")) {
-        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
-        Tensor mom_cpu;
-        paddle::framework::TensorCopySync(
-            *mom_tensor, platform::CPUPlace(), &mom_cpu);
-        momentum = mom_tensor->data<float>()[0];
-      }
-
-      int r = xpu::batch_norm<T>(dev_ctx.x_context(),
-                                 x_data,
-                                 y_data,
-                                 N,
-                                 C,
-                                 H,
-                                 W,
-                                 epsilon,
-                                 momentum,
-                                 scale_data,
-                                 bias_data,
-                                 saved_mean_data,
-                                 saved_variance_data,
-                                 mean_out_data,
-                                 variance_out_data,
-                                 is_nchw);
-      PADDLE_ENFORCE_EQ(r,
-                        xpu::Error_t::SUCCESS,
-                        platform::errors::External(
-                            "The batch_norm XPU API return wrong value[%d %s]",
-                            r,
-                            XPUAPIErrorMsg[r]));
-    } else {
-      const auto *mean = ctx.Input<Tensor>("Mean");
-      const auto *variance = ctx.Input<Tensor>("Variance");
-      const auto *mean_data = mean->data<T>();
-      const auto *variance_data = variance->data<T>();
-      int r = xpu::batch_norm_infer<T>(dev_ctx.x_context(),
-                                       x_data,
-                                       y_data,
-                                       N,
-                                       C,
-                                       H,
-                                       W,
-                                       epsilon,
-                                       scale_data,
-                                       bias_data,
-                                       mean_data,
-                                       variance_data,
-                                       is_nchw);
-      PADDLE_ENFORCE_EQ(
-          r,
-          xpu::Error_t::SUCCESS,
-          platform::errors::External(
-              "The batch_norm_infer XPU API return wrong value[%d %s]",
-              r,
-              XPUAPIErrorMsg[r]));
-    }
-  }
-};
-
-template <typename T>
-static int calculate_inv_BN_Y(xpu::Context *ctx,
-                              T *x,
-                              const T *scale,
-                              const T *bias,
-                              const T *mean,
-                              const T *variance,
-                              const int N,
-                              const int C,
-                              const int M,
-                              const T *y) {
-  PADDLE_ENFORCE_EQ(x,
-                    y,
-                    platform::errors::InvalidArgument(
-                        "X and Y should be inplaced in inplace mode"));
-  std::vector<int> tensor_shape_vec({N, C, M});
-  std::vector<int> array_shape_vec({1, C, 1});
-  // y - bias
-  int r1 =
-      xpu::broadcast_sub<T>(ctx, bias, y, x, array_shape_vec, tensor_shape_vec);
-  // (y - bias) / scale
-  int r2 = xpu::broadcast_div<T>(
-      ctx, scale, x, x, array_shape_vec, tensor_shape_vec);
-  // (y - bias) / scale / variance
-  int r3 = xpu::broadcast_div<T>(
-      ctx, variance, x, x, array_shape_vec, tensor_shape_vec);
-  // (y - bias) / scale / variance + mean
-  int r4 =
-      xpu::broadcast_add<T>(ctx, mean, x, x, array_shape_vec, tensor_shape_vec);
-
-  return r1 + r2 + r3 + r4;
-}
-
-template <typename T>
-static int calculate_inv_var(xpu::Context *ctx,
-                             const T *var,
-                             const T epsilon,
-                             const int C,
-                             T *epsilon_data,
-                             T *inv_var) {
-  int r1 = constant(ctx, epsilon_data, 1, epsilon);
-  std::vector<int> tensor_shape_vec({C});
-  std::vector<int> array_shape_vec({1});
-  int r2 = xpu::broadcast_add<T>(
-      ctx, epsilon_data, var, inv_var, array_shape_vec, tensor_shape_vec);
-  int r3 = xpu::rsqrt<T>(ctx, inv_var, inv_var, C);
-  return r1 + r2 + r3;
-}
-
-template <typename DeviceContext, typename T>
-class BatchNormGradXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *bias = ctx.Input<Tensor>("Bias");
-
-    const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
-    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
-    const bool is_test = ctx.Attr<bool>("is_test");
-    const float epsilon = ctx.Attr<float>("epsilon");
-    const auto data_layout = framework::StringToDataLayout(data_layout_str);
-
-    PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
-                      true,
-                      platform::errors::InvalidArgument(
-                          "The 'data_layout' attribute must be NCHW or NHWC. "
-                          "But recevived 'data_layout' is [%s].",
-                          data_layout_str));
-
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-    use_global_stats = is_test || use_global_stats;
-
-    // batch_norm with inplace as false will take X as grad input, which
-    // is same as cuDNN batch_norm backward calculation, batch_norm
-    // with inplace as true only take Y as input and X should be calculate
-    // by inverse operation of batch_norm on Y
-    const Tensor *x;
-    bool is_inplace;
-    if (ctx.HasInput("Y")) {
-      x = ctx.Input<Tensor>("Y");
-      is_inplace = true;
-      // if the input of batch norm is stop_gradient, d_x is null.
-      if (d_x) {
-        PADDLE_ENFORCE_EQ(d_x,
-                          d_y,
-                          platform::errors::InvalidArgument(
-                              "X@GRAD and Y@GRAD not inplace in inplace mode"));
-      }
-    } else {
-      x = ctx.Input<Tensor>("X");
-      is_inplace = false;
-      if (d_x) {
-        PADDLE_ENFORCE_NE(
-            d_x,
-            d_y,
-            platform::errors::InvalidArgument(
-                "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
-      }
-    }
-
-    const auto &x_dims = x->dims();
-    PADDLE_ENFORCE_EQ(
-        x_dims.size() >= 2 && x_dims.size() <= 5,
-        true,
-        platform::errors::InvalidArgument(
-            "The size of input's dimensions should be between 2 and 5"
-            "But received: the size of input's dimensions is [%d]",
-            x_dims.size()));
-
-    int N = -1, C = -1, H = -1, W = -1, D = -1;
-    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
-    N = (N == 0) ? 1 : N;
-    C = (C == 0) ? 1 : C;
-    H = (H == 0) ? 1 : H;
-    W = (W == 0) ? 1 : W;
-
-    const auto *x_data = x->data<T>();
-    const auto *d_y_data = d_y->data<T>();
-    const auto *scale_data = scale->data<T>();
-
-    // init output
-    T *d_x_data = nullptr;
-    T *d_bias_data = nullptr;
-    T *d_scale_data = nullptr;
-    if (d_x) {
-      d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
-    }
-    if (d_scale && d_bias) {
-      d_scale_data = d_scale->mutable_data<T>(ctx.GetPlace());
-      d_bias_data = d_bias->mutable_data<T>(ctx.GetPlace());
-    }
-
-    PADDLE_ENFORCE_EQ(
-        scale->dims().size(),
-        1UL,
-        platform::errors::InvalidArgument(
-            "The size of scale's dimensions must equal to 1. But received: "
-            "the size of scale's dimensions is [%d], the dimensions of scale "
-            "is [%s].",
-            scale->dims().size(),
-            scale->dims()));
-    PADDLE_ENFORCE_EQ(
-        scale->dims()[0],
-        C,
-        platform::errors::InvalidArgument(
-            "The first dimension of scale must equal to Channels[%d]. But "
-            "received: the first dimension of scale is [%d]",
-            C,
-            scale->dims()[0]));
-
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-
-    const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
-    const auto *batch_inv_std = ctx.Input<Tensor>("SavedVariance");
-    const auto *global_mean = ctx.Input<Tensor>("Mean");
-    const auto *global_var = ctx.Input<Tensor>("Variance");
-
-    // TODO(guozibin): hadle the situation case of N * H * W = 1
-    if (is_inplace) {
-      float *global_inv_std_data = nullptr;
-      if (use_global_stats) {
-        global_inv_std_data =
-            RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
-        float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
-        int r1 = calculate_inv_var(dev_ctx.x_context(),
-                                   global_var->data<float>(),
-                                   epsilon,
-                                   C,
-                                   epsilon_data,
-                                   global_inv_std_data);
-        PADDLE_ENFORCE_EQ(
-            r1,
-            XPU_SUCCESS,
-            platform::errors::External("XPU API(batch_norm_grad "
-                                       "calculate_inv_var function) "
-                                       "return wrong value[%d %s]",
-                                       r1,
-                                       XPUAPIErrorMsg[r1]));
-      }
-      auto px = *x;
-      auto *inv_std_data =
-          use_global_stats ? global_inv_std_data : batch_inv_std->data<float>();
-      auto mean_data = use_global_stats ? global_mean->data<float>()
-                                        : batch_mean->data<float>();
-      int r2 = calculate_inv_BN_Y(dev_ctx.x_context(),
-                                  px.mutable_data<T>(ctx.GetPlace()),
-                                  scale->data<T>(),
-                                  bias->data<T>(),
-                                  mean_data,
-                                  inv_std_data,
-                                  N,
-                                  C,
-                                  H * W,
-                                  x->data<T>());
-      PADDLE_ENFORCE_EQ(
-          r2,
-          XPU_SUCCESS,
-          platform::errors::External("XPU API(batch_norm_grad "
-                                     "calculate_inv_BN_Y function) "
-                                     "return wrong value[%d %s]",
-                                     r2,
-                                     XPUAPIErrorMsg[r2]));
-    }
-
-    int r3;
-    bool is_nchw = data_layout_str == "NCHW";
-    if (use_global_stats) {
-      r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
-                                   x_data,
-                                   d_y_data,
-                                   d_x_data,
-                                   N,
-                                   C,
-                                   H,
-                                   W,
-                                   scale_data,
-                                   nullptr,
-                                   nullptr,
-                                   d_scale_data,
-                                   d_bias_data,
-                                   is_nchw,
-                                   global_mean->data<float>(),
-                                   global_var->data<float>(),
-                                   epsilon);
-    } else {
-      if (!d_x) {
-        d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
-      }
-      if (!d_scale) {
-        d_scale_data = RAII_GUARD.alloc_l3_or_gm<T>(C);
-      }
-      if (!d_bias_data) {
-        d_bias_data = RAII_GUARD.alloc_l3_or_gm<T>(C);
-      }
-      r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
-                                   x_data,
-                                   d_y_data,
-                                   d_x_data,
-                                   N,
-                                   C,
-                                   H,
-                                   W,
-                                   scale_data,
-                                   batch_mean->data<float>(),
-                                   batch_inv_std->data<float>(),
-                                   d_scale_data,
-                                   d_bias_data,
-                                   is_nchw);
-    }
-    PADDLE_ENFORCE_EQ(
-        r3,
-        XPU_SUCCESS,
-        platform::errors::External("XPU API(batch_norm_grad) return "
-                                   "wrong value[%d %s]",
-                                   r3,
-                                   XPUAPIErrorMsg[r3]));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_XPU_KERNEL(
-    batch_norm,
-    ops::BatchNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
-REGISTER_OP_XPU_KERNEL(
-    batch_norm_grad,
-    ops::BatchNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
-
-#endif  // PADDLE_WITH_XPU
diff --git a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ccb9f601ed33296d34fd2a1d35663058aa775d40
--- /dev/null
+++ b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc
@@ -0,0 +1,277 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/batch_norm_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
+
+namespace phi {
+
+template <typename T>
+static int CalculateInvBNY(xpu::Context *ctx,
+                           T *x,
+                           const T *scale,
+                           const T *bias,
+                           const T *mean,
+                           const T *variance,
+                           const int N,
+                           const int C,
+                           const int M,
+                           const T *y) {
+  PADDLE_ENFORCE_EQ(x,
+                    y,
+                    phi::errors::InvalidArgument(
+                        "X and Y should be inplaced in inplace mode"));
+  std::vector<int> tensor_shape_vec({N, C, M});
+  std::vector<int> array_shape_vec({1, C, 1});
+  // y - bias
+  int r1 =
+      xpu::broadcast_sub<T>(ctx, bias, y, x, array_shape_vec, tensor_shape_vec);
+  // (y - bias) / scale
+  int r2 = xpu::broadcast_div<T>(
+      ctx, scale, x, x, array_shape_vec, tensor_shape_vec);
+  // (y - bias) / scale / variance
+  int r3 = xpu::broadcast_div<T>(
+      ctx, variance, x, x, array_shape_vec, tensor_shape_vec);
+  // (y - bias) / scale / variance + mean
+  int r4 =
+      xpu::broadcast_add<T>(ctx, mean, x, x, array_shape_vec, tensor_shape_vec);
+
+  return r1 + r2 + r3 + r4;
+}
+
+template <typename T>
+static int CalculateInvVar(xpu::Context *ctx,
+                           const T *var,
+                           const T epsilon,
+                           const int C,
+                           T *epsilon_data,
+                           T *inv_var) {
+  int r1 = constant(ctx, epsilon_data, 1, epsilon);
+  std::vector<int> tensor_shape_vec({C});
+  std::vector<int> array_shape_vec({1});
+  int r2 = xpu::broadcast_add<T>(
+      ctx, epsilon_data, var, inv_var, array_shape_vec, tensor_shape_vec);
+  int r3 = xpu::rsqrt<T>(ctx, inv_var, inv_var, C);
+  return r1 + r2 + r3;
+}
+
+template <typename T, typename Context>
+void BatchNormGradKernel(const Context &dev_ctx,
+                         const DenseTensor &x,
+                         const DenseTensor &scale,
+                         const DenseTensor &bias,
+                         const paddle::optional<DenseTensor> &mean,
+                         const paddle::optional<DenseTensor> &variance,
+                         const DenseTensor &saved_mean,
+                         const DenseTensor &saved_variance,
+                         const paddle::optional<DenseTensor> &reserve_space,
+                         const DenseTensor &y_grad,
+                         float momentum,
+                         float epsilon,
+                         const std::string &data_layout,
+                         bool is_test,
+                         bool use_global_stats,
+                         bool trainable_statistics,
+                         bool fuse_with_relu,
+                         DenseTensor *x_grad,
+                         DenseTensor *scale_grad,
+                         DenseTensor *bias_grad) {
+  const auto *d_y = &y_grad;
+  PADDLE_ENFORCE_EQ(data_layout == "NCHW" || data_layout == "NHWC",
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The 'data_layout' attribute must be NCHW or NHWC. "
+                        "But recevived 'data_layout' is [%s].",
+                        data_layout));
+
+  const auto data_layout_val =
+      paddle::framework::StringToDataLayout(data_layout);
+
+  use_global_stats = is_test || use_global_stats;
+
+  // batch_norm with inplace as false will take X as grad input, which
+  // is same as cuDNN batch_norm backward calculation, batch_norm
+  // with inplace as true only take Y as input and X should be calculate
+  // by inverse operation of batch_norm on Y
+  bool is_inplace = false;
+  if (x_grad) {
+    PADDLE_ENFORCE_NE(x_grad,
+                      d_y,
+                      phi::errors::InvalidArgument(
+                          "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+  }
+
+  const auto &x_dims = x.dims();
+  PADDLE_ENFORCE_EQ(
+      x_dims.size() >= 2 && x_dims.size() <= 5,
+      true,
+      phi::errors::InvalidArgument(
+          "The size of input's dimensions should be between 2 and 5"
+          "But received: the size of input's dimensions is [%d]",
+          x_dims.size()));
+
+  int N = -1, C = -1, H = -1, W = -1, D = -1;
+  funcs::ExtractNCWHD(x_dims, data_layout_val, &N, &C, &H, &W, &D);
+  N = (N == 0) ? 1 : N;
+  C = (C == 0) ? 1 : C;
+  H = (H == 0) ? 1 : H;
+  W = (W == 0) ? 1 : W;
+
+  const auto *x_data = x.data<T>();
+  const auto *d_y_data = y_grad.data<T>();
+  const auto *scale_data = scale.data<T>();
+
+  // init output
+  T *x_grad_data = nullptr;
+  T *bias_grad_data = nullptr;
+  T *scale_grad_data = nullptr;
+  if (x_grad) {
+    x_grad_data = dev_ctx.template Alloc<T>(x_grad);
+  }
+  if (scale_grad && bias_grad) {
+    scale_grad_data = dev_ctx.template Alloc<T>(scale_grad);
+    bias_grad_data = dev_ctx.template Alloc<T>(bias_grad);
+  }
+
+  PADDLE_ENFORCE_EQ(
+      scale.dims().size(),
+      1UL,
+      phi::errors::InvalidArgument(
+          "The size of scale's dimensions must equal to 1. But received: "
+          "the size of scale's dimensions is [%d], the dimensions of scale "
+          "is [%s].",
+          scale.dims().size(),
+          scale.dims()));
+  PADDLE_ENFORCE_EQ(
+      scale.dims()[0],
+      C,
+      phi::errors::InvalidArgument(
+          "The first dimension of scale must equal to Channels[%d]. But "
+          "received: the first dimension of scale is [%d]",
+          C,
+          scale.dims()[0]));
+
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+
+  const auto *global_mean = mean.get_ptr();
+  const auto *global_var = variance.get_ptr();
+
+  // TODO(guozibin): hadle the situation case of N * H * W = 1
+  if (is_inplace) {
+    float *global_inv_std_data = nullptr;
+    if (use_global_stats) {
+      global_inv_std_data =
+          RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
+      float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
+      int r1 = CalculateInvVar(dev_ctx.x_context(),
+                               global_var->data<float>(),
+                               epsilon,
+                               C,
+                               epsilon_data,
+                               global_inv_std_data);
+      PADDLE_ENFORCE_EQ(r1,
+                        XPU_SUCCESS,
+                        phi::errors::External("XPU API(batch_norm_grad "
+                                              "CalculateInvVar function) "
+                                              "return wrong value[%d %s]",
+                                              r1,
+                                              XPUAPIErrorMsg[r1]));
+    }
+
+    // Here is a trick, x is a const input,
+    // but trans to a non-const var, is it risky?
+    auto px = x;
+    auto *inv_std_data =
+        use_global_stats ? global_inv_std_data : saved_variance.data<float>();
+    auto *mean_data = use_global_stats ? global_mean->data<float>()
+                                       : saved_mean.data<float>();
+    int r2 = CalculateInvBNY(dev_ctx.x_context(),
+                             px.data<T>(),
+                             scale.data<T>(),
+                             bias.data<T>(),
+                             mean_data,
+                             inv_std_data,
+                             N,
+                             C,
+                             H * W,
+                             x.data<T>());
+    PADDLE_ENFORCE_EQ(r2,
+                      XPU_SUCCESS,
+                      phi::errors::External("XPU API(batch_norm_grad "
+                                            "CalculateInvBNY function) "
+                                            "return wrong value[%d %s]",
+                                            r2,
+                                            XPUAPIErrorMsg[r2]));
+  }
+
+  int r3;
+  bool is_nchw = data_layout == "NCHW";
+  if (use_global_stats) {
+    r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
+                                 x_data,
+                                 d_y_data,
+                                 x_grad_data,
+                                 N,
+                                 C,
+                                 H,
+                                 W,
+                                 scale_data,
+                                 nullptr,
+                                 nullptr,
+                                 scale_grad_data,
+                                 bias_grad_data,
+                                 is_nchw,
+                                 global_mean->data<float>(),
+                                 global_var->data<float>(),
+                                 epsilon);
+  } else {
+    if (!x_grad) {
+      x_grad_data = RAII_GUARD.alloc_l3_or_gm<T>(x.numel());
+    }
+    if (!scale_grad) {
+      scale_grad_data = RAII_GUARD.alloc_l3_or_gm<T>(C);
+    }
+    if (!bias_grad_data) {
+      bias_grad_data = RAII_GUARD.alloc_l3_or_gm<T>(C);
+    }
+    r3 = xpu::batch_norm_grad<T>(dev_ctx.x_context(),
+                                 x_data,
+                                 d_y_data,
+                                 x_grad_data,
+                                 N,
+                                 C,
+                                 H,
+                                 W,
+                                 scale_data,
+                                 saved_mean.data<float>(),
+                                 saved_variance.data<float>(),
+                                 scale_grad_data,
+                                 bias_grad_data,
+                                 is_nchw);
+  }
+  PADDLE_ENFORCE_EQ(r3,
+                    XPU_SUCCESS,
+                    phi::errors::External("XPU API(batch_norm_grad) return "
+                                          "wrong value[%d %s]",
+                                          r3,
+                                          XPUAPIErrorMsg[r3]));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    batch_norm_grad, XPU, ALL_LAYOUT, phi::BatchNormGradKernel, float) {}
diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d29e56a7d7a026c93652f43a17abcd8ddf9df172
--- /dev/null
+++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/batch_norm_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void BatchNormKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const DenseTensor& scale,
+                     const DenseTensor& bias,
+                     const DenseTensor& mean,
+                     const DenseTensor& variance,
+                     float momentum,
+                     float epsilon,
+                     const std::string& data_layout_str,
+                     bool is_test,
+                     bool use_global_stats,
+                     bool trainable_statistics,
+                     bool fuse_with_relu,
+                     DenseTensor* y,
+                     DenseTensor* mean_out,
+                     DenseTensor* variance_out,
+                     DenseTensor* saved_mean,
+                     DenseTensor* saved_variance,
+                     DenseTensor* reserve_space) {
+  bool test_mode = is_test && (!trainable_statistics);
+  bool global_stats = test_mode || use_global_stats;
+  const auto data_layout =
+      paddle::framework::StringToDataLayout(data_layout_str);
+  PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The 'data_layout' attribute must be NCHW or NHWC. "
+                        "But recevived 'data_layout' is [%s].",
+                        data_layout_str));
+
+  const auto& x_dims = x.dims();
+  PADDLE_ENFORCE_EQ(
+      x_dims.size() >= 2 && x_dims.size() <= 5,
+      true,
+      phi::errors::InvalidArgument(
+          "The size of input's dimensions should be between 2 and 5"
+          "But received: the size of input's dimensions is [%d]",
+          x_dims.size()));
+
+  int N = -1, C = -1, H = -1, W = -1, D = -1;
+  funcs::ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+  N = (N == 0) ? 1 : N;
+  C = (C == 0) ? 1 : C;
+  H = (H == 0) ? 1 : H;
+  W = (W == 0) ? 1 : W;
+
+  const auto* x_data = x.data<T>();
+  const auto* scale_data = scale.data<T>();
+  const auto* bias_data = bias.data<T>();
+
+  // alloc memory
+  auto* y_data = dev_ctx.template Alloc<T>(y);
+  dev_ctx.template Alloc<T>(mean_out);
+  dev_ctx.template Alloc<T>(variance_out);
+  dev_ctx.template Alloc<T>(saved_mean);
+  dev_ctx.template Alloc<T>(saved_variance);
+
+  bool is_nchw = data_layout_str == "NCHW";
+
+  if (!global_stats) {
+    auto* mean_out_data = mean_out->data<T>();
+    auto* variance_out_data = variance_out->data<T>();
+    auto* saved_mean_data = saved_mean->data<T>();
+    auto* saved_variance_data = saved_variance->data<T>();
+
+    int r = xpu::batch_norm<T>(dev_ctx.x_context(),
+                               x_data,
+                               y_data,
+                               N,
+                               C,
+                               H,
+                               W,
+                               epsilon,
+                               momentum,
+                               scale_data,
+                               bias_data,
+                               saved_mean_data,
+                               saved_variance_data,
+                               mean_out_data,
+                               variance_out_data,
+                               is_nchw);
+    PADDLE_ENFORCE_EQ(r,
+                      xpu::Error_t::SUCCESS,
+                      phi::errors::External(
+                          "The batch_norm XPU API return wrong value[%d %s]",
+                          r,
+                          XPUAPIErrorMsg[r]));
+  } else {
+    const auto* mean_data = mean.data<T>();
+    const auto* variance_data = variance.data<T>();
+    int r = xpu::batch_norm_infer<T>(dev_ctx.x_context(),
+                                     x_data,
+                                     y_data,
+                                     N,
+                                     C,
+                                     H,
+                                     W,
+                                     epsilon,
+                                     scale_data,
+                                     bias_data,
+                                     mean_data,
+                                     variance_data,
+                                     is_nchw);
+    PADDLE_ENFORCE_EQ(
+        r,
+        xpu::Error_t::SUCCESS,
+        phi::errors::External(
+            "The batch_norm_infer XPU API return wrong value[%d %s]",
+            r,
+            XPUAPIErrorMsg[r]));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(batch_norm, XPU, ALL_LAYOUT, phi::BatchNormKernel, float) {}