diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
index 252fbe7d21b74406fa2a89d10f1ed77230555f09..cfad86506c9099d08b92471edf27cf437ba85dd9 100644
--- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx,
     //           ctx.GetPlace()),
     //       epsilon, saved_mean_data, saved_var_data));
 #else
-    // CUDNN only support small batch size
-    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
-    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
-    const bool use_native_kernel =
-        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
-    if (use_native_kernel) {
-      if (x_dims.size() == 2) {
+    }
+    // CUDNN only support small batch size
+    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
+    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
+    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
+    bool use_native_nhwc =
+        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+            : false;
+    const bool use_native_kernel =
+        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
+      if (use_native_kernel || use_native_nhwc) {
+        if (x_dims.size() == 2 || use_native_nhwc) {
         dim3 block;
         dim3 grid;
         const int block_size = 512;
@@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx,
                   flag_ptr);
         }
         // 2. reduce_sum(x, dy, mean) => dscale, dbias
+        BatchNormParamType<T> *dscale = nullptr;
+        BatchNormParamType<T> *dbias = nullptr;
+        bool with_scale = false;
+        if (d_scale && d_bias) {
+          dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+          dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+        } else {
+          DenseTensor dscale_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          DenseTensor dbias_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          dscale = dscale_mem.data<BatchNormParamType<T>>();
+          dbias = dbias_mem.data<BatchNormParamType<T>>();
+        }
+
         BNBackward2DChannelLastStage2<T, block_size>
             <<<grid, block, 0, ctx.stream()>>>(
                 transformed_d_y.template data<T>(),
@@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 H * W * D,
                 epsilon,
                 block_data_ptr,
-                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+                dscale,
+                dbias,
                 flag_ptr);
 
         // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 transformed_d_y.template data<T>(),
                 transformed_x.template data<T>(),
                 scale.template data<BatchNormParamType<T>>(),
-                d_scale->data<BatchNormParamType<T>>(),
-                d_bias->data<BatchNormParamType<T>>(),
+                dscale,
+                dbias,
                 mean_ptr,
                 variance_ptr,
                 C,
@@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx,
         paddle::platform::dynload::cudnnDestroyTensorDescriptor(
             bn_param_desc_));
 #endif
+  } else {
     const auto *running_mean = mean.get_ptr();
     const auto *running_var = variance.get_ptr();