From c42cbb145c185869f7d4020666b5fcae9f59bcc0 Mon Sep 17 00:00:00 2001
From: zhangkaihuo <zhangkaihuo@baidu.com>
Date: Mon, 8 Aug 2022 18:10:17 +0800
Subject: [PATCH] BN1D inference support large batch_size (#44977)

---
 paddle/phi/kernels/gpu/batch_norm_kernel.cu | 74 +++++++++++++++------
 1 file changed, 55 insertions(+), 19 deletions(-)

diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu
index 67698369278..5c6fd04c15e 100644
--- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu
@@ -691,6 +691,9 @@ void BatchNormKernel(const Context &ctx,
 
   auto handle = ctx.cudnn_handle();
 
+  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
+  const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
+
   // Now, depending on whether we are running test or not, we have two paths.
   // It is training mode when it's not reference AND not using pre-trained
   // model.
@@ -793,23 +796,58 @@ void BatchNormKernel(const Context &ctx,
 //           est_var->template data<BatchNormParamType<T>>())),
 //           epsilon));
 #else
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          paddle::platform::dynload::cudnnBatchNormalizationForwardInference(
-              handle,
-              // Note: PERSISTENT not implemented for inference
-              CUDNN_BATCHNORM_SPATIAL,
-              CudnnDataType<T>::kOne(),
-              CudnnDataType<T>::kZero(),
-              data_desc_,
-              transformed_x.template data<T>(),
-              data_desc_,
-              ctx.template Alloc<T>(&transformed_y),
-              bn_param_desc_,
-              scale.template data<BatchNormParamType<T>>(),
-              bias.template data<BatchNormParamType<T>>(),
-              est_mean->template data<BatchNormParamType<T>>(),
-              est_var->template data<BatchNormParamType<T>>(),
-              epsilon));
+      const bool use_native_kernel =
+          ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+           (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+      if (use_native_kernel) {
+        const int block_size = 256;
+        const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
+        if (compute_format == DataLayout::kNCHW) {
+          BNForwardInference<T, DataLayout::kNCHW>
+              <<<grid_size, block_size, 0, ctx.stream()>>>(
+                  transformed_x.template data<T>(),
+                  est_mean->template data<BatchNormParamType<T>>(),
+                  est_var->template data<BatchNormParamType<T>>(),
+                  scale.template data<BatchNormParamType<T>>(),
+                  bias.template data<BatchNormParamType<T>>(),
+                  C,
+                  N,
+                  H * W * D,
+                  epsilon,
+                  transformed_y.template data<T>());
+        } else {
+          BNForwardInference<T, DataLayout::kNHWC>
+              <<<grid_size, block_size, 0, ctx.stream()>>>(
+                  transformed_x.template data<T>(),
+                  est_mean->template data<BatchNormParamType<T>>(),
+                  est_var->template data<BatchNormParamType<T>>(),
+                  scale.template data<BatchNormParamType<T>>(),
+                  bias.template data<BatchNormParamType<T>>(),
+                  C,
+                  N,
+                  H * W * D,
+                  epsilon,
+                  transformed_y.template data<T>());
+        }
+      } else {
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            paddle::platform::dynload::cudnnBatchNormalizationForwardInference(
+                handle,
+                // Note: PERSISTENT not implemented for inference
+                CUDNN_BATCHNORM_SPATIAL,
+                CudnnDataType<T>::kOne(),
+                CudnnDataType<T>::kZero(),
+                data_desc_,
+                transformed_x.template data<T>(),
+                data_desc_,
+                ctx.template Alloc<T>(&transformed_y),
+                bn_param_desc_,
+                scale.template data<BatchNormParamType<T>>(),
+                bias.template data<BatchNormParamType<T>>(),
+                est_mean->template data<BatchNormParamType<T>>(),
+                est_var->template data<BatchNormParamType<T>>(),
+                epsilon));
+      }
 #endif
     } else {
       // if MomentumTensor is set, use MomentumTensor value, momentum
@@ -909,8 +947,6 @@ void BatchNormKernel(const Context &ctx,
 //                 BatchNormParamType<T>>(ctx.GetPlace()))));
 #else
 //   const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
-      const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-      const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
       const bool use_native_kernel =
           ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
            (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
-- 
GitLab