diff --git a/paddle/fluid/operators/inplace_abn_op.cu b/paddle/fluid/operators/inplace_abn_op.cu
index 71e21a2edd47bcae5d5428bd1a599a1a8093174a..a74150a3306726d405879c1f75369f890442badc 100644
--- a/paddle/fluid/operators/inplace_abn_op.cu
+++ b/paddle/fluid/operators/inplace_abn_op.cu
@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/inplace_abn_op.h"
+#include
 #include "paddle/fluid/operators/batch_norm_op.h"
-#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"
 #include "paddle/phi/kernels/batch_norm_grad_kernel.h"
 #include "paddle/phi/kernels/batch_norm_kernel.h"
+#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
+#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
+#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
 
 namespace paddle {
 namespace operators {
 
 template <typename DeviceContext, typename T>
-class InplaceABNKernel
-    : public paddle::operators::SyncBatchNormKernel<DeviceContext, T> {
+class InplaceABNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* y = ctx.Output<Tensor>("Y");
@@ -36,29 +38,49 @@ class InplaceABNKernel
         GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
 
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* mean = ctx.Input<Tensor>("Mean");
+    auto* variance = ctx.Input<Tensor>("Variance");
+
+    auto momentum = ctx.Attr<float>("momentum");
+    auto epsilon = ctx.Attr<float>("epsilon");
+    auto data_layout = ctx.Attr<std::string>("data_layout");
+    auto is_test = ctx.Attr<bool>("is_test");
+    auto use_global_stats = ctx.Attr<bool>("use_global_stats");
+    auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
+    auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
+
+    auto* mean_out = ctx.Output<Tensor>("MeanOut");
+    auto* variance_out = ctx.Output<Tensor>("VarianceOut");
+    auto* saved_mean = ctx.Output<Tensor>("SavedMean");
+    auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
+    auto* reserve_space = ctx.Output<Tensor>("ReserveSpace");
+
     if (ctx.Attr<bool>("use_sync_bn")) {
-      SyncBatchNormKernel<DeviceContext, T>::Compute(ctx);
+      auto& dev_ctx = ctx.device_context();
+      phi::SyncBatchNormKernel<T>(
+          static_cast<const typename framework::ConvertToPhiContext<
+              DeviceContext>::TYPE&>(dev_ctx),
+          *x,
+          *scale,
+          *bias,
+          *mean,
+          *variance,
+          momentum,
+          epsilon,
+          data_layout,
+          is_test,
+          use_global_stats,
+          trainable_statistics,
+          fuse_with_relu,
+          y,
+          mean_out,
+          variance_out,
+          saved_mean,
+          saved_variance,
+          reserve_space);
     } else {
-      // BatchNormKernel<DeviceContext, T>::Compute(ctx);
-      auto* scale = ctx.Input<Tensor>("Scale");
-      auto* bias = ctx.Input<Tensor>("Bias");
-      auto* mean = ctx.Input<Tensor>("Mean");
-      auto* variance = ctx.Input<Tensor>("Variance");
-
-      auto momentum = ctx.Attr<float>("momentum");
-      auto epsilon = ctx.Attr<float>("epsilon");
-      auto data_layout = ctx.Attr<std::string>("data_layout");
-      auto is_test = ctx.Attr<bool>("is_test");
-      auto use_global_stats = ctx.Attr<bool>("use_global_stats");
-      auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
-      auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
-
-      auto* mean_out = ctx.Output<Tensor>("MeanOut");
-      auto* variance_out = ctx.Output<Tensor>("VarianceOut");
-      auto* saved_mean = ctx.Output<Tensor>("SavedMean");
-      auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
-      auto* reserve_space = ctx.Output<Tensor>("ReserveSpace");
-
       auto& dev_ctx = ctx.device_context();
       phi::BatchNormKernel<T>(
           static_cast
-class InplaceABNGradKernel
-    : public paddle::operators::SyncBatchNormGradKernel<DeviceContext, T> {
+class InplaceABNGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto* y = ctx.Input<Tensor>("Y");
@@ -115,29 +136,44 @@ class InplaceABNGradKernel
     InplaceABNActivation<DeviceContext, T> functor;
     functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);
 
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* saved_mean = ctx.Input<Tensor>("SavedMean");
+    auto* saved_variance = ctx.Input<Tensor>("SavedVariance");
+
+    auto momentum = ctx.Attr<float>("momentum");
+    auto epsilon = ctx.Attr<float>("epsilon");
+    auto data_layout = ctx.Attr<std::string>("data_layout");
+    auto is_test = ctx.Attr<bool>("is_test");
+    auto use_global_stats = ctx.Attr<bool>("use_global_stats");
+    auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
+    auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
+
+    auto* scale_grad = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto* bias_grad = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
+    auto* mean = ctx.Input<Tensor>("ReserveSpace");
+    auto* variance = ctx.Input<Tensor>("ReserveSpace");
+
     if (ctx.Attr<bool>("use_sync_bn")) {
-      SyncBatchNormGradKernel<DeviceContext, T>::Compute(ctx);
+      auto& dev_ctx = ctx.device_context();
+      phi::SyncBatchNormGradFunctor<T>(
+          static_cast<const typename framework::ConvertToPhiContext<
+              DeviceContext>::TYPE&>(dev_ctx),
+          nullptr,
+          y,
+          *scale,
+          *bias,
+          *saved_mean,
+          *saved_variance,
+          *d_y,
+          epsilon,
+          data_layout,
+          d_x,
+          scale_grad,
+          bias_grad);
     } else {
-      auto* scale = ctx.Input<Tensor>("Scale");
-      auto* bias = ctx.Input<Tensor>("Bias");
-      auto* saved_mean = ctx.Input<Tensor>("SavedMean");
-      auto* saved_variance = ctx.Input<Tensor>("SavedVariance");
-
-      auto momentum = ctx.Attr<float>("momentum");
-      auto epsilon = ctx.Attr<float>("epsilon");
-      auto data_layout = ctx.Attr<std::string>("data_layout");
-      auto is_test = ctx.Attr<bool>("is_test");
-      auto use_global_stats = ctx.Attr<bool>("use_global_stats");
-      auto trainable_statistics = ctx.Attr<bool>("trainable_statistics");
-      auto fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
-
-      auto* scale_grad = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-      auto* bias_grad = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-      auto* reserve_space = ctx.Input<Tensor>("ReserveSpace");
-      auto* mean = ctx.Input<Tensor>("ReserveSpace");
-      auto* variance = ctx.Input<Tensor>("ReserveSpace");
-
       paddle::optional<Tensor> space_opt;
       paddle::optional<Tensor> mean_opt;
       paddle::optional<Tensor> variance_opt;
diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu
deleted file mode 100644
index 637064398e177afcf1baaf4872960d8a6ca4e069..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/sync_batch_norm_op.cu
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#include "paddle/fluid/operators/sync_batch_norm_op.cu.h" - -namespace paddle { -namespace operators { - -template -class SyncBatchNormKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - double epsilon = static_cast(ctx.Attr("epsilon")); - const float momentum = ctx.Attr("momentum"); - const bool is_test = ctx.Attr("is_test"); - const std::string layout_str = ctx.Attr("data_layout"); - const DataLayout layout = framework::StringToDataLayout(layout_str); - const bool use_global_stats = ctx.Attr("use_global_stats"); - const bool trainable_stats = ctx.Attr("trainable_statistics"); - PADDLE_ENFORCE_EQ(use_global_stats, - false, - platform::errors::InvalidArgument( - "sync_batch_norm doesn't support " - "to set use_global_stats True. Please use batch_norm " - "in this case.")); - - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Y"); - - const auto *est_mean = ctx.Input("Mean"); - const auto *est_var = ctx.Input("Variance"); - - // moving mean/variance - auto *mean_out = ctx.Output("MeanOut"); - auto *variance_out = ctx.Output("VarianceOut"); - - auto *saved_mean = ctx.Output("SavedMean"); - auto *saved_inv_variance = ctx.Output("SavedVariance"); - - bool test_mode = is_test && (!trainable_stats); - SyncBatchNormFunctor(ctx, - layout, - x, - y, - est_mean, - est_var, - mean_out, - variance_out, - saved_mean, - saved_inv_variance, - epsilon, - momentum, - test_mode, - use_global_stats); - } -}; - -template -class SyncBatchNormGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(ctx.GetPlace()), - true, - platform::errors::InvalidArgument("It must use CUDAPlace.")); - double epsilon = static_cast(ctx.Attr("epsilon")); - const std::string layout_str = ctx.Attr("data_layout"); - - const DataLayout layout = framework::StringToDataLayout(layout_str); - const auto *d_y = ctx.Input(framework::GradVarName("Y")); - const auto *scale = ctx.Input("Scale"); - const auto *bias = ctx.Input("Bias"); - - // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - - const auto *saved_mean = ctx.Input("SavedMean"); - const auto *saved_inv_var = ctx.Input("SavedVariance"); - - SyncBatchNormGradFunctor(ctx, - layout, - scale, - bias, - d_x, - d_y, - d_scale, - d_bias, - saved_mean, - saved_inv_var, - epsilon); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -#ifdef PADDLE_WITH_HIP -// MIOPEN do not support double -REGISTER_OP_CUDA_KERNEL( - sync_batch_norm, - ops::SyncBatchNormKernel, - ops::SyncBatchNormKernel); -REGISTER_OP_CUDA_KERNEL( - sync_batch_norm_grad, - ops::SyncBatchNormGradKernel, - ops::SyncBatchNormGradKernel); -#else -REGISTER_OP_CUDA_KERNEL( - sync_batch_norm, - ops::SyncBatchNormKernel, - ops::SyncBatchNormKernel, - ops::SyncBatchNormKernel); -REGISTER_OP_CUDA_KERNEL( - sync_batch_norm_grad, - ops::SyncBatchNormGradKernel, - ops::SyncBatchNormGradKernel, - ops::SyncBatchNormGradKernel); -#endif - -// clang-format on diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu.h b/paddle/fluid/operators/sync_batch_norm_op.cu.h deleted file mode 100644 index 47de27e876922928cf09450086de31e8dcef5a90..0000000000000000000000000000000000000000 --- 
a/paddle/fluid/operators/sync_batch_norm_op.cu.h +++ /dev/null @@ -1,637 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/data_layout.h" -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/batch_norm_op.h" -#include "paddle/fluid/operators/norm_utils.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using DataLayout = framework::DataLayout; -template -using CudnnDataType = platform::CudnnDataType; -template -using BatchNormParamType = typename CudnnDataType::BatchNormParamType; - -template -__global__ void KeLocalStats( - const T *x, int N, int M, int C, BatchNormParamType *mean_var) { - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int k = blockIdx.x; k < C; k += gridDim.x) { - BatchNormParamType x_sum = 0.; - BatchNormParamType x2_sum = 0.; - for (int i = threadIdx.x; i < N * M; i += BlockDim) { - int id = layout == framework::DataLayout::kNCHW - ? (i / M) * C * M + k * M + i % M - : i * C + k; - auto x_in = static_cast>(x[id]); - x_sum += x_in; - x2_sum += x_in * x_in; - } - __syncthreads(); - auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - mean_var[k] = out / (N * M); - } - out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - mean_var[k + C] = out / (N * M); - } - } - if (blockIdx.x == 0 && threadIdx.x == 0) { - mean_var[2 * C] = static_cast>(1.0); - } -} - -template -__global__ void KeSyncAndMovingStats(BatchNormParamType *means, - BatchNormParamType *variances, - BatchNormParamType *num_dev, - const int C, - const BatchNormParamType momentum, - const double epsilon, - BatchNormParamType *sv_mean_data, - BatchNormParamType *sv_inv_var_data, - BatchNormParamType *moving_means, - BatchNormParamType *moving_variances) { - // sync stats across multi-devices - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (int i = gid; i < C; i += stride) { - auto mean = means[i] / (*num_dev); - auto var = variances[i] / (*num_dev); - var = var - mean * mean; - - // sync stats - sv_mean_data[i] = mean; - sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon); - variances[i] = var; - - // moving stats - moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum); - moving_variances[i] = - moving_variances[i] * momentum + var * (1. 
- momentum); - } -} - -template -static __global__ void KeNormAffine(const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const BatchNormParamType *mean, - const BatchNormParamType *variance, - const double epsilon, - const int C, - const int M, - const int num, - T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (int i = gid; i < num; i += stride) { - const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C; - auto x_i = static_cast>(x[i]); - auto y_i = - (x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; - y[i] = static_cast(y_i); - } -} - -template -void SyncBatchNormFunctor(const framework::ExecutionContext &ctx, - const DataLayout layout, - const framework::Tensor *x, - framework::Tensor *y, - const framework::Tensor *mean, - const framework::Tensor *variance, - framework::Tensor *mean_out, - framework::Tensor *variance_out, - framework::Tensor *saved_mean, - framework::Tensor *saved_variance, - double epsilon, - const float momentum, - const bool is_test, - const bool use_global_stats - -) { - const auto &x_dims = x->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input dim size should be larger than 1.")); - PADDLE_ENFORCE_LE(x_dims.size(), - 5, - platform::errors::InvalidArgument( - "The Input dim size should be less than 6.")); - int N, C, H, W, D; - ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); - int x_numel = x->numel(); - - const T *x_d = x->data(); - const auto *s_d = ctx.Input("Scale")->data>(); - const auto *b_d = ctx.Input("Bias")->data>(); - - T *y_d = y->mutable_data(ctx.GetPlace()); - - const BatchNormParamType *mean_data = nullptr; - const BatchNormParamType *var_data = nullptr; - - auto &dev_ctx = ctx.cuda_device_context(); - auto stream = dev_ctx.stream(); - const int block = 512; - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - - paddle::memory::AllocationPtr alloc_ptr{nullptr}; - - if (is_test) { - mean_data = mean->data>(); - var_data = variance->data>(); - } else { - // x, x^2, 1, here 1 is used to calc device num - // device num also can be got from platform::DeviceContextPool - const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); - alloc_ptr = memory::Alloc(dev_ctx, bytes); - - auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); - const int threads = 256; - int grid = std::min(C, (max_threads + threads - 1) / threads); - if (layout == framework::DataLayout::kNCHW) { - KeLocalStats - <<>>(x_d, N, H * W * D, C, stats); - } else { - KeLocalStats - <<>>(x_d, N, H * W * D, C, stats); - } - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto *comm = dev_ctx.nccl_comm(); - if (comm) { - int dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(mean_out->dtype())); - // In-place operation - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclAllReduce(stats, - stats, - 2 * C + 1, - static_cast(dtype), - ncclSum, - comm, - stream)); - } -#endif - - auto *est_mean_data = - mean_out->mutable_data>(ctx.GetPlace()); - auto *est_var_data = - variance_out->mutable_data>(ctx.GetPlace()); - - auto *sv_mean_data = - saved_mean->mutable_data>(ctx.GetPlace()); - auto *sv_inv_var_data = - saved_variance->mutable_data>(ctx.GetPlace()); - - // Note, Input('Mean')/Input('Variance') share variable with - // Output('MeanOut')/Output('VarianceOut') - KeSyncAndMovingStats - <<<(C + block - 1) / block, block, 0, stream>>>(stats, - stats + C, - stats + 2 * C, - C, - momentum, - 
epsilon, - sv_mean_data, - sv_inv_var_data, - est_mean_data, - est_var_data); - - mean_data = sv_mean_data; - var_data = stats + C; - } - - int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; - if (layout == framework::DataLayout::kNCHW) { - KeNormAffine - <<>>(x_d, - s_d, - b_d, - mean_data, - var_data, - epsilon, - C, - H * W * D, - x_numel, - y_d); - } else { - KeNormAffine - <<>>(x_d, - s_d, - b_d, - mean_data, - var_data, - epsilon, - C, - H * W * D, - x_numel, - y_d); - } -} - -template -__global__ void KeBackwardLocalStats(const T *dy, - const T *x, - const BatchNormParamType *means, - int N, - int M, - int C, - BatchNormParamType *sum_dy_prod) { - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int k = blockIdx.x; k < C; k += gridDim.x) { - BatchNormParamType sum1 = 0.; - BatchNormParamType sum2 = 0.; - auto mean = means[k]; - for (int i = threadIdx.x; i < N * M; i += blockDim.x) { - int id = layout == framework::DataLayout::kNCHW - ? (i / M) * C * M + k * M + i % M - : i * C + k; - auto g = static_cast>(dy[id]); - sum1 += g; - auto x_i = static_cast>(x[id]); - sum2 += g * (x_i - mean); - } - - __syncthreads(); - auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - sum_dy_prod[k] = out; - } - out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - sum_dy_prod[k + C] = out; - } - } - if (blockIdx.x == 0 && threadIdx.x == 0) { - sum_dy_prod[2 * C] = 1.0; - } -} - -template -static __global__ void KeBNBackwardScaleBias( - const T *dy, - const T *x, - const BatchNormParamType *mean, - const BatchNormParamType *inv_variance, - const double epsilon, - const int N, - const int C, - const int HxW, - BatchNormParamType *dscale, - BatchNormParamType *dbias) { - const int outer_size = C; - const int inner_size = N * HxW; - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - BatchNormParamType ds_sum = 0.; - BatchNormParamType db_sum = 0.; - - auto inv_var_i = inv_variance[i]; - auto mean_i = mean[i]; - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int id = layout == framework::DataLayout::kNCHW - ? ((j / HxW) * C + i) * HxW + (j % HxW) - : j * outer_size + i; - auto x_i = static_cast>(x[id]); - auto dy_i = static_cast>(dy[id]); - ds_sum += dy_i * (x_i - mean_i); - db_sum += dy_i; - } - __syncthreads(); - auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum()); - __syncthreads(); - auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - dscale[i] = os * inv_var_i; - dbias[i] = ob; - } - __syncthreads(); - } -} - -template -static __global__ void KeBNRestoreData(T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const BatchNormParamType *mean, - const BatchNormParamType *sv_inv, - const double epsilon, - int C, - int M, - int num, - const T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (int i = gid; i < num; i += stride) { - const int c = layout == framework::DataLayout::kNCHW ? 
(i / M) % C : i % C; - auto y_i = static_cast>(y[i]); - auto x_i = (y_i - bias[c]) / scale[c] / sv_inv[c] + mean[c]; - x[i] = static_cast(x_i); - } -} - -template -static __global__ void KeBNBackwardData( - const T *dy, - const T *x, - const BatchNormParamType *gamma, - const BatchNormParamType *mean, - const BatchNormParamType *inv_variance, - const BatchNormParamType *g_sum_dy, - const BatchNormParamType *g_sum_dy_prod, - const BatchNormParamType *num_dev, - const double epsilon, - const int C, - const int HxW, - const int num, - T *dx) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - auto scale = static_cast>(C) / num; - auto dev_num = num_dev[0]; - for (int i = gid; i < num; i += stride) { - const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; - auto inv_var = inv_variance[c]; - auto s_d = gamma[c]; - auto gvar = - -(g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var); - auto gmean = -(g_sum_dy[c] / dev_num) * s_d * inv_var; - - auto x_i = static_cast>(x[i]); - auto dy_i = static_cast>(dy[i]); - auto dx_i = - dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]); - dx[i] = static_cast(dx_i); - } -} - -template -void SyncBatchNormGradFunctor(const framework::ExecutionContext &ctx, - const DataLayout layout, - const framework::Tensor *scale, - const framework::Tensor *bias, - framework::Tensor *d_x, - const framework::Tensor *d_y, - framework::Tensor *d_scale, - framework::Tensor *d_bias, - const framework::Tensor *mean, - const framework::Tensor *variance, - const double epsilon) { - // sync_batch_norm with inplace as false will take X as grad input, which - // is same as cuDNN batch_norm backward calculation, batch_norm - // with inplace as true only take Y as input and X should be calculate - // by inverse operation of batch_norm on Y - const Tensor *x; - bool is_inplace; - if (ctx.HasInput("Y")) { - x = ctx.Input("Y"); - is_inplace = true; - } else { - x = ctx.Input("X"); - is_inplace = false; - } - - const auto &x_dims = x->dims(); - - PADDLE_ENFORCE_GE(x_dims.size(), - 2, - platform::errors::InvalidArgument( - "The Input X dim size should be larger than 1.")); - PADDLE_ENFORCE_LE(x_dims.size(), - 5, - platform::errors::InvalidArgument( - "The Input X dim size should be less than 6.")); - - int N, C, H, W, D; - ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); - PADDLE_ENFORCE_EQ(scale->dims()[0], - C, - platform::errors::InvalidArgument( - "Expected first dim for input parameter(scale) of " - "OP(sync_batch_norm) be (%d), but given (%d).", - C, - scale->dims()[0])); - - d_x->mutable_data(ctx.GetPlace()); - if (d_scale && d_bias) { - d_scale->mutable_data>(ctx.GetPlace()); - d_bias->mutable_data>(ctx.GetPlace()); - } - PADDLE_ENFORCE_EQ(scale->dims().size(), - 1UL, - platform::errors::InvalidArgument( - "Expected rank for input parameter(scale) of " - "OP(sync_batch_norm) be (1), but given (%d).", - scale->dims().size())); - - std::vector dims; - std::vector strides; - if (layout == DataLayout::kNCHW) { - dims = {N, C, H, W, D}; - strides = {C * H * W * D, H * W * D, W * D, D, 1}; - } else { - dims = {N, C, H, W, D}; - strides = {H * W * C * D, 1, W * D * C, D * C, C}; - } - const T *x_d = x->data(); - auto px = *x; - const T *dy_d = d_y->data(); - - auto &dev_ctx = ctx.cuda_device_context(); - auto stream = dev_ctx.stream(); - - const auto *saved_mean = mean->data>(); - const auto *saved_inv_var = variance->data>(); - const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); - auto 
alloc_ptr = memory::Alloc(dev_ctx, bytes); - auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); - - const int block = 512; - const int threads = 256; - int x_numel = x->numel(); - int fsize = H * W * D; - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - int grid = std::min(C, (max_threads + threads - 1) / threads); - int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; - - if (is_inplace) { - if (layout == framework::DataLayout::kNCHW) { - KeBNRestoreData - <<>>(px.mutable_data(ctx.GetPlace()), - scale->data>(), - bias->data>(), - saved_mean, - saved_inv_var, - epsilon, - C, - H * W * D, - x_numel, - x->data()); - } else { - KeBNRestoreData - <<>>(px.mutable_data(ctx.GetPlace()), - scale->data>(), - bias->data>(), - saved_mean, - saved_inv_var, - epsilon, - C, - H * W * D, - x_numel, - x->data()); - } - } - - if (layout == framework::DataLayout::kNCHW) { - KeBackwardLocalStats - <<>>( - dy_d, x_d, saved_mean, N, fsize, C, stats); - } else { - KeBackwardLocalStats - <<>>( - dy_d, x_d, saved_mean, N, fsize, C, stats); - } - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto *comm = dev_ctx.nccl_comm(); - if (comm) { - int dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(scale->dtype())); - // In-place operation - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclAllReduce(stats, - stats, - 2 * C + 1, - static_cast(dtype), - ncclSum, - comm, - stream)); - } -#endif - - if (layout == framework::DataLayout::kNCHW) { - if (d_scale && d_bias) { - KeBNBackwardScaleBias - <<>>(dy_d, - x_d, - saved_mean, - saved_inv_var, - epsilon, - N, - C, - fsize, - d_scale->data>(), - d_bias->data>()); - } - if (d_x) { - KeBNBackwardData - <<>>(dy_d, - x_d, - scale->data>(), - saved_mean, - saved_inv_var, - stats, - stats + C, - stats + 2 * C, - epsilon, - C, - fsize, - x->numel(), - d_x->data()); - } - } else { - if (d_scale && d_bias) { - KeBNBackwardScaleBias - <<>>(dy_d, - x_d, - saved_mean, - saved_inv_var, - epsilon, - N, - C, - fsize, - d_scale->data>(), - d_bias->data>()); - } - if (d_x) { - KeBNBackwardData - <<>>(dy_d, - x_d, - scale->data>(), - saved_mean, - saved_inv_var, - stats, - stats + C, - stats + 2 * C, - epsilon, - C, - fsize, - x->numel(), - d_x->data()); - } - } -} - -template -class SyncBatchNormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override; -}; - -// Deriving the Gradient for the Backward Pass of Batch Normalization -// https://kevinzakka.github.io/2016/09/14/batch_normalization/ -template -class SyncBatchNormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 453a0d9c1669039eed722752e9364284d942782d..ed08fe48ee8497fb6ba62840b3ba981b258387df 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2075,6 +2075,16 @@ func : swish backward : swish_grad +# sync_batch_norm +- api : sync_batch_norm + args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) + output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space) + infer_meta : + func : BatchNormInferMeta + kernel : + func : sync_batch_norm + 
+  backward : sync_batch_norm_grad
+
 # take_along_axis
 - api : take_along_axis
   args : (Tensor x, Tensor index, int axis)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 50aa57a3845cd3f105fc708f74effba5aa636534..91464ac769f77eddbb8084fc71d0a53abbc64dca 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -2085,6 +2085,18 @@
     func : swish_grad
     inplace : (out_grad -> x_grad)
 
+- backward_api : sync_batch_norm_grad
+  forward : sync_batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
+  args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
+  output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
+  infer_meta :
+    func : GeneralTernaryGradInferMeta
+    param : [x, scale, bias]
+  kernel :
+    func : sync_batch_norm_grad
+    data_type : out_grad
+  optional : mean_out, variance_out, reserve_space
+
 - backward_api : take_along_axis_grad
   forward : take_along_axis (Tensor x, Tensor index, int axis) -> Tensor(out)
   args : (Tensor x, Tensor index, Tensor out_grad, int axis)
diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ba5020d08bd0f6953eb5ba31e1dd1f5287c0435c
--- /dev/null
+++ b/paddle/phi/kernels/gpu/sync_batch_norm_grad_kernel.cu
@@ -0,0 +1,75 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sync_batch_norm_grad_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SyncBatchNormGradKernel(const Context& ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& scale,
+                             const DenseTensor& bias,
+                             const paddle::optional<DenseTensor>& mean,
+                             const paddle::optional<DenseTensor>& variance,
+                             const DenseTensor& saved_mean,
+                             const DenseTensor& saved_variance,
+                             const paddle::optional<DenseTensor>& reserve_space,
+                             const DenseTensor& y_grad,
+                             float momentum,
+                             float epsilon_f,
+                             const std::string& data_layout_str,
+                             bool is_test,
+                             bool use_global_stats,
+                             bool trainable_statistics,
+                             bool fuse_with_relu,
+                             DenseTensor* x_grad,
+                             DenseTensor* scale_grad,
+                             DenseTensor* bias_grad) {
+  SyncBatchNormGradFunctor<T, Context>(ctx,
+                                       &x,
+                                       nullptr,
+                                       scale,
+                                       bias,
+                                       saved_mean,
+                                       saved_variance,
+                                       y_grad,
+                                       epsilon_f,
+                                       data_layout_str,
+                                       x_grad,
+                                       scale_grad,
+                                       bias_grad);
+}
+
+}  // namespace phi
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(sync_batch_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SyncBatchNormGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(sync_batch_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SyncBatchNormGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a1d4b681ca0536cd4ad59c872aa3a5e4c96bc454
--- /dev/null
+++ b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
@@ -0,0 +1,190 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sync_batch_norm_kernel.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SyncBatchNormKernel(const Context &ctx,
+                         const DenseTensor &x,
+                         const DenseTensor &scale,
+                         const DenseTensor &bias,
+                         const DenseTensor &mean,
+                         const DenseTensor &variance,
+                         float momentum,
+                         float epsilon_f,
+                         const std::string &data_layout_str,
+                         bool is_test,
+                         bool use_global_stats,
+                         bool trainable_statistics,
+                         bool fuse_with_relu,
+                         DenseTensor *y,
+                         DenseTensor *mean_out,
+                         DenseTensor *variance_out,
+                         DenseTensor *saved_mean,
+                         DenseTensor *saved_variance,
+                         DenseTensor *reserve_space) {
+  PADDLE_ENFORCE_EQ(use_global_stats,
+                    false,
+                    phi::errors::InvalidArgument(
+                        "sync_batch_norm doesn't support "
+                        "to set use_global_stats True.
Please use batch_norm " + "in this case.")); + + double epsilon = epsilon_f; + const bool trainable_stats = trainable_statistics; + const DataLayout layout = + paddle::framework::StringToDataLayout(data_layout_str); + bool test_mode = is_test && (!trainable_statistics); + const auto &x_dims = x.dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + phi::errors::InvalidArgument( + "The Input dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), + 5, + phi::errors::InvalidArgument( + "The Input dim size should be less than 6.")); + int N, C, H, W, D; + funcs::ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + int x_numel = x.numel(); + + const T *x_d = x.template data(); + const auto *s_d = scale.template data>(); + const auto *b_d = bias.template data>(); + + T *y_d = ctx.template Alloc(y); + + const BatchNormParamType *mean_data = nullptr; + const BatchNormParamType *var_data = nullptr; + + auto stream = ctx.stream(); + const int block = 512; + int max_threads = ctx.GetMaxPhysicalThreadCount(); + + paddle::memory::AllocationPtr alloc_ptr{nullptr}; + + if (test_mode) { + mean_data = mean.template data>(); + var_data = variance.template data>(); + } else { + // x, x^2, 1, here 1 is used to calc device num + // device num also can be got from platform::DeviceContextPool + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); + alloc_ptr = paddle::memory::Alloc(ctx, bytes); + + auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); + const int threads = 256; + int grid = std::min(C, (max_threads + threads - 1) / threads); + if (layout == paddle::framework::DataLayout::kNCHW) { + KeLocalStats + <<>>(x_d, N, H * W * D, C, stats); + } else { + KeLocalStats + <<>>(x_d, N, H * W * D, C, stats); + } + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto *comm = ctx.nccl_comm(); + if (comm) { + int dtype = paddle::platform::ToNCCLDataType( + paddle::framework::TransToProtoVarType(mean_out->dtype())); + // In-place operation + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce( + stats, + stats, + 2 * C + 1, + static_cast(dtype), + ncclSum, + comm, + stream)); + } +#endif + + auto *est_mean_data = ctx.template Alloc>(mean_out); + auto *est_var_data = + ctx.template Alloc>(variance_out); + + auto *sv_mean_data = ctx.template Alloc>(saved_mean); + auto *sv_inv_var_data = + ctx.template Alloc>(saved_variance); + + // Note, Input('Mean')/Input('Variance') share variable with + // Output('MeanOut')/Output('VarianceOut') + KeSyncAndMovingStats + <<<(C + block - 1) / block, block, 0, stream>>>(stats, + stats + C, + stats + 2 * C, + C, + momentum, + epsilon, + sv_mean_data, + sv_inv_var_data, + est_mean_data, + est_var_data); + + mean_data = sv_mean_data; + var_data = stats + C; + } + + int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; + if (layout == paddle::framework::DataLayout::kNCHW) { + KeNormAffine + <<>>(x_d, + s_d, + b_d, + mean_data, + var_data, + epsilon, + C, + H * W * D, + x_numel, + y_d); + } else { + KeNormAffine + <<>>(x_d, + s_d, + b_d, + mean_data, + var_data, + epsilon, + C, + H * W * D, + x_numel, + y_d); + } +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(sync_batch_norm, + GPU, + ALL_LAYOUT, + phi::SyncBatchNormKernel, + float, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(sync_batch_norm, + GPU, + ALL_LAYOUT, + phi::SyncBatchNormKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_utils.h 
b/paddle/phi/kernels/gpu/sync_batch_norm_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..37b9bca73a857beae48ca9c9c87f5c07a0e51a26 --- /dev/null +++ b/paddle/phi/kernels/gpu/sync_batch_norm_utils.h @@ -0,0 +1,493 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/kernels/funcs/norm_utils.h" + +namespace phi { + +template +using CudnnDataType = paddle::platform::CudnnDataType; +template +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; + +template +__global__ void KeLocalStats( + const T *x, int N, int M, int C, BatchNormParamType *mean_var) { + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + for (int k = blockIdx.x; k < C; k += gridDim.x) { + BatchNormParamType x_sum = 0.; + BatchNormParamType x2_sum = 0.; + for (int i = threadIdx.x; i < N * M; i += BlockDim) { + int id = layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M + : i * C + k; + auto x_in = static_cast>(x[id]); + x_sum += x_in; + x2_sum += x_in * x_in; + } + __syncthreads(); + auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + mean_var[k] = out / (N * M); + } + out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + mean_var[k + C] = out / (N * M); + } + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + mean_var[2 * C] = static_cast>(1.0); + } +} + +template +__global__ void KeSyncAndMovingStats(BatchNormParamType *means, + BatchNormParamType *variances, + BatchNormParamType *num_dev, + const int C, + const BatchNormParamType momentum, + const double epsilon, + BatchNormParamType *sv_mean_data, + BatchNormParamType *sv_inv_var_data, + BatchNormParamType *moving_means, + BatchNormParamType *moving_variances) { + // sync stats across multi-devices + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < C; i += stride) { + auto mean = means[i] / (*num_dev); + auto var = variances[i] / (*num_dev); + var = var - mean * mean; + + // sync stats + sv_mean_data[i] = mean; + sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon); + variances[i] = var; + + // moving stats + moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum); + moving_variances[i] = + moving_variances[i] * momentum + var * (1. 
- momentum); + } +} + +template +static __global__ void KeNormAffine(const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *variance, + const double epsilon, + const int C, + const int M, + const int num, + T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == DataLayout::kNCHW ? (i / M) % C : i % C; + auto x_i = static_cast>(x[i]); + auto y_i = + (x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; + y[i] = static_cast(y_i); + } +} + +template +__global__ void KeBackwardLocalStats(const T *dy, + const T *x, + const BatchNormParamType *means, + int N, + int M, + int C, + BatchNormParamType *sum_dy_prod) { + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + for (int k = blockIdx.x; k < C; k += gridDim.x) { + BatchNormParamType sum1 = 0.; + BatchNormParamType sum2 = 0.; + auto mean = means[k]; + for (int i = threadIdx.x; i < N * M; i += blockDim.x) { + int id = layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M + : i * C + k; + auto g = static_cast>(dy[id]); + sum1 += g; + auto x_i = static_cast>(x[id]); + sum2 += g * (x_i - mean); + } + + __syncthreads(); + auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + sum_dy_prod[k] = out; + } + out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + sum_dy_prod[k + C] = out; + } + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + sum_dy_prod[2 * C] = 1.0; + } +} + +template +static __global__ void KeBNBackwardScaleBias( + const T *dy, + const T *x, + const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, + const double epsilon, + const int N, + const int C, + const int HxW, + BatchNormParamType *dscale, + BatchNormParamType *dbias) { + const int outer_size = C; + const int inner_size = N * HxW; + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType ds_sum = 0.; + BatchNormParamType db_sum = 0.; + + auto inv_var_i = inv_variance[i]; + auto mean_i = mean[i]; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int id = layout == DataLayout::kNCHW + ? ((j / HxW) * C + i) * HxW + (j % HxW) + : j * outer_size + i; + auto x_i = static_cast>(x[id]); + auto dy_i = static_cast>(dy[id]); + ds_sum += dy_i * (x_i - mean_i); + db_sum += dy_i; + } + __syncthreads(); + auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum()); + __syncthreads(); + auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + dscale[i] = os * inv_var_i; + dbias[i] = ob; + } + __syncthreads(); + } +} + +template +static __global__ void KeBNRestoreData(T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *sv_inv, + const double epsilon, + int C, + int M, + int num, + const T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == DataLayout::kNCHW ? 
(i / M) % C : i % C; + auto y_i = static_cast>(y[i]); + auto x_i = (y_i - bias[c]) / scale[c] / sv_inv[c] + mean[c]; + x[i] = static_cast(x_i); + } +} + +template +static __global__ void KeBNBackwardData( + const T *dy, + const T *x, + const BatchNormParamType *gamma, + const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, + const BatchNormParamType *g_sum_dy, + const BatchNormParamType *g_sum_dy_prod, + const BatchNormParamType *num_dev, + const double epsilon, + const int C, + const int HxW, + const int num, + T *dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + auto scale = static_cast>(C) / num; + auto dev_num = num_dev[0]; + for (int i = gid; i < num; i += stride) { + const int c = layout == DataLayout::kNCHW ? i / HxW % C : i % C; + auto inv_var = inv_variance[c]; + auto s_d = gamma[c]; + auto gvar = + -(g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var); + auto gmean = -(g_sum_dy[c] / dev_num) * s_d * inv_var; + + auto x_i = static_cast>(x[i]); + auto dy_i = static_cast>(dy[i]); + auto dx_i = + dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]); + dx[i] = static_cast(dx_i); + } +} + +template +void SyncBatchNormGradFunctor( + const Context &ctx, + const DenseTensor *input_x, + const DenseTensor *input_y, + const DenseTensor &scale, + const DenseTensor &bias, + // const paddle::optional& mean, + // const paddle::optional& variance, + const DenseTensor &saved_mean, + const DenseTensor &saved_variance, + // const paddle::optional& reserve_space, + const DenseTensor &y_grad, + // float momentum, + float epsilon_f, + const std::string &data_layout_str, + // bool is_test, + // bool use_global_stats, + // bool trainable_statistics, + // bool fuse_with_relu, + DenseTensor *x_grad, + DenseTensor *scale_grad, + DenseTensor *bias_grad) { + double epsilon = static_cast(epsilon_f); + + const DataLayout layout = + paddle::framework::StringToDataLayout(data_layout_str); + + const auto *d_y = &y_grad; + + auto *d_x = x_grad; + auto *d_scale = scale_grad; + auto *d_bias = bias_grad; + + const DenseTensor *x; + bool is_inplace = false; + if (input_y) { + is_inplace = true; + x = input_y; + } else { + x = input_x; + } + const auto &x_dims = x->dims(); + + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + phi::errors::InvalidArgument( + "The Input X dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), + 5, + phi::errors::InvalidArgument( + "The Input X dim size should be less than 6.")); + + int N, C, H, W, D; + funcs::ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + PADDLE_ENFORCE_EQ(scale.dims()[0], + C, + phi::errors::InvalidArgument( + "Expected first dim for input parameter(scale) of " + "OP(sync_batch_norm) be (%d), but given (%d).", + C, + scale.dims()[0])); + + ctx.template Alloc(d_x); + if (d_scale && d_bias) { + ctx.template Alloc>(d_scale); + ctx.template Alloc>(d_bias); + } + PADDLE_ENFORCE_EQ(scale.dims().size(), + 1UL, + phi::errors::InvalidArgument( + "Expected rank for input parameter(scale) of " + "OP(sync_batch_norm) be (1), but given (%d).", + scale.dims().size())); + + std::vector dims; + std::vector strides; + if (layout == DataLayout::kNCHW) { + dims = {N, C, H, W, D}; + strides = {C * H * W * D, H * W * D, W * D, D, 1}; + } else { + dims = {N, C, H, W, D}; + strides = {H * W * C * D, 1, W * D * C, D * C, C}; + } + const T *x_d = x->data(); + auto px = *x; + const T *dy_d = d_y->data(); + + auto stream = ctx.stream(); + + const auto *saved_mean_ptr = + 
saved_mean.template data>(); + const auto *saved_inv_var = + saved_variance.template data>(); + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); + auto alloc_ptr = paddle::memory::Alloc(ctx, bytes); + auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); + + const int block = 512; + const int threads = 256; + int x_numel = x->numel(); + int fsize = H * W * D; + int max_threads = ctx.GetMaxPhysicalThreadCount(); + int grid = std::min(C, (max_threads + threads - 1) / threads); + int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; + + if (is_inplace) { + if (layout == DataLayout::kNCHW) { + KeBNRestoreData<<>>( + ctx.template Alloc(&px), + scale.template data>(), + bias.template data>(), + saved_mean_ptr, + saved_inv_var, + epsilon, + C, + H * W * D, + x_numel, + x->data()); + } else { + KeBNRestoreData<<>>( + ctx.template Alloc(&px), + scale.template data>(), + bias.template data>(), + saved_mean_ptr, + saved_inv_var, + epsilon, + C, + H * W * D, + x_numel, + x->data()); + } + } + + if (layout == DataLayout::kNCHW) { + KeBackwardLocalStats + <<>>( + dy_d, x_d, saved_mean_ptr, N, fsize, C, stats); + } else { + KeBackwardLocalStats + <<>>( + dy_d, x_d, saved_mean_ptr, N, fsize, C, stats); + } + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto *comm = ctx.nccl_comm(); + if (comm) { + int dtype = paddle::platform::ToNCCLDataType( + paddle::framework::TransToProtoVarType(scale.dtype())); + // In-place operation + PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce( + stats, + stats, + 2 * C + 1, + static_cast(dtype), + ncclSum, + comm, + stream)); + } +#endif + + if (layout == DataLayout::kNCHW) { + if (d_scale && d_bias) { + KeBNBackwardScaleBias + <<>>(dy_d, + x_d, + saved_mean_ptr, + saved_inv_var, + epsilon, + N, + C, + fsize, + d_scale->data>(), + d_bias->data>()); + } + if (d_x) { + KeBNBackwardData<<>>( + dy_d, + x_d, + scale.template data>(), + saved_mean_ptr, + saved_inv_var, + stats, + stats + C, + stats + 2 * C, + epsilon, + C, + fsize, + x->numel(), + d_x->data()); + } + } else { + if (d_scale && d_bias) { + KeBNBackwardScaleBias + <<>>(dy_d, + x_d, + saved_mean_ptr, + saved_inv_var, + epsilon, + N, + C, + fsize, + d_scale->data>(), + d_bias->data>()); + } + if (d_x) { + KeBNBackwardData<<>>( + dy_d, + x_d, + scale.template data>(), + saved_mean_ptr, + saved_inv_var, + stats, + stats + C, + stats + 2 * C, + epsilon, + C, + fsize, + x->numel(), + d_x->data()); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/sync_batch_norm_grad_kernel.h b/paddle/phi/kernels/sync_batch_norm_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..395bec23f1091a4f24bf9d20dd7975118f253056 --- /dev/null +++ b/paddle/phi/kernels/sync_batch_norm_grad_kernel.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SyncBatchNormGradKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& scale,
+                             const DenseTensor& bias,
+                             const paddle::optional<DenseTensor>& mean,
+                             const paddle::optional<DenseTensor>& variance,
+                             const DenseTensor& saved_mean,
+                             const DenseTensor& saved_variance,
+                             const paddle::optional<DenseTensor>& reserve_space,
+                             const DenseTensor& y_grad,
+                             float momentum,
+                             float epsilon,
+                             const std::string& data_layout,
+                             bool is_test,
+                             bool use_global_stats,
+                             bool trainable_statistics,
+                             bool fuse_with_relu,
+                             DenseTensor* x_grad,
+                             DenseTensor* scale_grad,
+                             DenseTensor* bias_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/sync_batch_norm_kernel.h b/paddle/phi/kernels/sync_batch_norm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..5071eaabf8653404951566c44b0294ef3b4441c7
--- /dev/null
+++ b/paddle/phi/kernels/sync_batch_norm_kernel.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SyncBatchNormKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& scale,
+                         const DenseTensor& bias,
+                         const DenseTensor& mean,
+                         const DenseTensor& variance,
+                         float momentum,
+                         float epsilon,
+                         const std::string& data_layout,
+                         bool is_test,
+                         bool use_global_stats,
+                         bool trainable_statistics,
+                         bool fuse_with_relu,
+                         DenseTensor* y,
+                         DenseTensor* mean_out,
+                         DenseTensor* variance_out,
+                         DenseTensor* saved_mean,
+                         DenseTensor* saved_variance,
+                         DenseTensor* reserve_space);
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/sync_batch_norm_sig.cc b/paddle/phi/ops/compat/sync_batch_norm_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2595f241ff2335ea419cd2096cfefef570cedf1c
--- /dev/null
+++ b/paddle/phi/ops/compat/sync_batch_norm_sig.cc
@@ -0,0 +1,67 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature SyncBatchNormOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("sync_batch_norm",
+                         {"X", "Scale", "Bias", "Mean", "Variance"},
+                         {"momentum",
+                          "epsilon",
+                          "data_layout",
+                          "is_test",
+                          "use_global_stats",
+                          "trainable_statistics",
+                          "fuse_with_relu"},
+                         {"Y",
+                          "MeanOut",
+                          "VarianceOut",
+                          "SavedMean",
+                          "SavedVariance",
+                          "ReserveSpace"});
+}
+
+KernelSignature SyncBatchNormGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("sync_batch_norm_grad",
+                         {
+                             "X",
+                             "Scale",
+                             "Bias",
+                             "Mean",
+                             "Variance",
+                             "SavedMean",
+                             "SavedVariance",
+                             "ReserveSpace",
+                             "Y@GRAD",
+                         },
+                         {"momentum",
+                          "epsilon",
+                          "data_layout",
+                          "is_test",
+                          "use_global_stats",
+                          "trainable_statistics",
+                          "fuse_with_relu"},
+                         {"X@GRAD", "Scale@GRAD", "Bias@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(sync_batch_norm,
+                           phi::SyncBatchNormOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(sync_batch_norm_grad,
+                           phi::SyncBatchNormGradOpArgumentMapping);
diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index e549859fe626d13e82e170d2f81d820a40765b6f..b9081d0c8e682370d0bf478636e4e28abf61d999 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -49,6 +49,7 @@ from .. import functional as F
 from paddle import _C_ops
 from .. import Layer
 from paddle import in_dynamic_mode
+from paddle.fluid.framework import in_dygraph_mode
 
 __all__ = []
 
@@ -1100,7 +1101,14 @@
 
         ### train mode: use mini-batch stats, eval mode: use global stats
         ### use_global_stats only support False in sync_batch_norm
-        if in_dynamic_mode():
+        if in_dygraph_mode():
+            sync_batch_norm_out, _, _, _, _, _ = _C_ops.final_state_sync_batch_norm(
+                x, self.weight, self.bias, self._mean, self._variance,
+                self._momentum, self._epsilon, self._data_format,
+                not self.training, False, False, False)
+            return sync_batch_norm_out
+
+        elif in_dynamic_mode():
             attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                      "is_test", not self.training, "data_layout",
                      self._data_format, "use_mkldnn", False, "fuse_with_relu",
@@ -1109,7 +1117,6 @@
                      False, "use_global_stats", False, 'trainable_statistics',
                      False)
             sync_batch_norm_out, _, _, _, _, _ = _C_ops.sync_batch_norm(
                 x, self.weight, self.bias, self._mean, self._variance,
                 mean_out, variance_out, *attrs)
-
             return sync_batch_norm_out
 
         check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
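
Note (not part of the patch): a minimal usage sketch of the migrated layer, assuming a two-GPU dygraph run launched with paddle.distributed.launch; the file name, input shape, and toy loss are made up for illustration. It simply exercises paddle.nn.SyncBatchNorm so that the in_dygraph_mode() branch above dispatches to the new phi sync_batch_norm / sync_batch_norm_grad kernels.

# demo_sync_bn.py -- hypothetical file name; run as:
#   python -m paddle.distributed.launch --gpus 0,1 demo_sync_bn.py
import paddle
import paddle.distributed as dist


def main():
    dist.init_parallel_env()  # set up NCCL communicators for cross-card statistics
    layer = paddle.DataParallel(paddle.nn.SyncBatchNorm(num_features=8))
    x = paddle.rand([4, 8, 16, 16])  # NCHW input; shape chosen only for the demo
    y = layer(x)         # forward pass exercises the sync_batch_norm kernel
    y.mean().backward()  # backward pass exercises sync_batch_norm_grad
    print(y.shape)


if __name__ == "__main__":
    main()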