From eb6107402435ae6bc50da1a8d17e17f152b5e879 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Fri, 10 Feb 2023 10:12:14 +0800
Subject: [PATCH] [cherry-pick] Fix bn performance degradation (#50382)

att, cherry-pick: #48563 , #50287
---
 .../phi/kernels/gpu/batch_norm_grad_kernel.cu | 27 +++++++++++++++----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
index 5acccdfcea3..658b42e6165 100644
--- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
   }
   // CUDNN only support small batch size
   bool use_native_nhwc =
-      d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+      d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
+             H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
           : false;
   const bool use_native_kernel =
       ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx,
             flag_ptr);
       }
       // 2. reduce_sum(x, dy, mean) => dscale, dbias
+      BatchNormParamType<T> *dscale = nullptr;
+      BatchNormParamType<T> *dbias = nullptr;
+      bool with_scale = false;
+      if (d_scale && d_bias) {
+        dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+        dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+      } else {
+        DenseTensor dscale_mem =
+            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+        DenseTensor dbias_mem =
+            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+        dscale = dscale_mem.data<BatchNormParamType<T>>();
+        dbias = dbias_mem.data<BatchNormParamType<T>>();
+      }
+
       BNBackward2DChannelLastStage2<T, block_size>
           <<<grid, block, 0, ctx.stream()>>>(
               transformed_d_y.template data<T>(),
@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               H * W * D,
               epsilon,
               block_data_ptr,
-              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-              ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+              dscale,
+              dbias,
               flag_ptr);

       // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               transformed_d_y.template data<T>(),
               transformed_x.template data<T>(),
               scale.template data<BatchNormParamType<T>>(),
-              d_scale->data<BatchNormParamType<T>>(),
-              d_bias->data<BatchNormParamType<T>>(),
+              dscale,
+              dbias,
               mean_ptr,
               variance_ptr,
               C,
@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx,
         paddle::platform::dynload::cudnnDestroyTensorDescriptor(
             bn_param_desc_));
 #endif
+
   } else {
     const auto *running_mean = mean.get_ptr();
     const auto *running_var = variance.get_ptr();
-- 
GitLab
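
The sketch below is not part of the patch; it is a minimal, self-contained C++ illustration of the fallback-allocation pattern the second hunk introduces: the dx stage always reads dscale/dbias, so when the caller does not request d_scale/d_bias the kernel must route those pointers to scratch storage instead of dereferencing null outputs (the patch does this with phi::Empty tensors). All names here (ParamBuffer, ReduceScaleBias, BatchNormBackwardSketch) are hypothetical stand-ins, not the Paddle API, and the reduction is a toy, not the real channel-last batch-norm math.

// Standalone C++ sketch (hypothetical names): route optional gradient outputs
// through local pointers so a later stage always has valid buffers.
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a per-channel parameter gradient buffer.
using ParamBuffer = std::vector<float>;

// Toy stage-2 reduction: accumulate per-channel dscale/dbias from x and dy.
void ReduceScaleBias(const std::vector<float>& x,
                     const std::vector<float>& dy,
                     int C,
                     float* dscale,
                     float* dbias) {
  for (int c = 0; c < C; ++c) {
    dscale[c] = 0.f;
    dbias[c] = 0.f;
  }
  for (std::size_t i = 0; i < x.size(); ++i) {
    const int c = static_cast<int>(i) % C;
    dscale[c] += dy[i] * x[i];  // stand-in for sum(dy * x_hat)
    dbias[c] += dy[i];          // stand-in for sum(dy)
  }
}

// d_scale / d_bias may be null, mirroring the patch's
// `if (d_scale && d_bias) { Alloc outputs } else { allocate scratch }` branch.
void BatchNormBackwardSketch(const std::vector<float>& x,
                             const std::vector<float>& dy,
                             int C,
                             ParamBuffer* d_scale,   // may be null
                             ParamBuffer* d_bias) {  // may be null
  float* dscale = nullptr;
  float* dbias = nullptr;
  ParamBuffer dscale_mem, dbias_mem;  // scratch, used only on the else path
  if (d_scale && d_bias) {
    d_scale->resize(C);
    d_bias->resize(C);
    dscale = d_scale->data();
    dbias = d_bias->data();
  } else {
    dscale_mem.resize(C);
    dbias_mem.resize(C);
    dscale = dscale_mem.data();
    dbias = dbias_mem.data();
  }
  ReduceScaleBias(x, dy, C, dscale, dbias);
  // Stage 3 (dx) would read dscale/dbias here; both pointers are valid
  // whether or not the caller asked for the parameter gradients.
  std::printf("dscale[0]=%f dbias[0]=%f\n", dscale[0], dbias[0]);
}

int main() {
  const std::vector<float> x{1.f, 2.f, 3.f, 4.f};
  const std::vector<float> dy{0.1f, 0.2f, 0.3f, 0.4f};
  BatchNormBackwardSketch(x, dy, /*C=*/2, nullptr, nullptr);  // no param grads requested
  return 0;
}

The design point, as suggested by the diff, is that the later BNBackwardData launch consumes dscale/dbias unconditionally, so the branch guarantees those pointers are backed by either the requested output tensors or throwaway scratch, rather than by null d_scale/d_bias.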