diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 58d05d60758161ebda5b31ccf3dfcc17de32c9a4..5e73edcb3488f6299340d27179522f5c41b6ef5b 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -783,7 +783,8 @@ void BatchNormGradRawKernel(const Context &ctx, } // CUDNN only support small batch size bool use_native_nhwc = - d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC) + d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC && + H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL) : false; const bool use_native_kernel = ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||