From 6f1ec935058a85db9bc36259e02de248cd807b17 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 8 Feb 2023 15:10:28 +0800 Subject: [PATCH] Fix bn performance degradation (#50287) * fix bn performance degradation --- paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 58d05d6075..5e73edcb34 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -783,7 +783,8 @@ void BatchNormGradRawKernel(const Context &ctx, } // CUDNN only support small batch size bool use_native_nhwc = - d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC) + d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC && + H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL) : false; const bool use_native_kernel = ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || -- GitLab