diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 1758463141cb8f9510b7d7f8d62e69f0ce0e4013..42e1e2e7463c7753fbf205c88442db63733754ea 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -225,11 +225,17 @@ class BatchNormKernel #elif CUDNN_VERSION_MIN(7, 0, 1) if (FLAGS_cudnn_batchnorm_spatial_persistent) { mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } else if (H == 1 && W == 1) { + mode_ = CUDNN_BATCHNORM_PER_ACTIVATION; } else { mode_ = CUDNN_BATCHNORM_SPATIAL; } #else - mode_ = CUDNN_BATCHNORM_SPATIAL; + if (H == 1 && W == 1) { + mode_ = CUDNN_BATCHNORM_PER_ACTIVATION; + } else { + mode_ = CUDNN_BATCHNORM_SPATIAL; + } #endif // CUDNN_VERSION_MIN(7, 0, 1) VLOG(3) << "Setting descriptors."; @@ -989,11 +995,17 @@ class BatchNormGradKernel #elif CUDNN_VERSION_MIN(7, 0, 1) if (FLAGS_cudnn_batchnorm_spatial_persistent) { mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } else if (H == 1 && W == 1) { + mode_ = CUDNN_BATCHNORM_PER_ACTIVATION; } else { mode_ = CUDNN_BATCHNORM_SPATIAL; } #else - mode_ = CUDNN_BATCHNORM_SPATIAL; + if (H == 1 && W == 1) { + mode_ = CUDNN_BATCHNORM_PER_ACTIVATION; + } else { + mode_ = CUDNN_BATCHNORM_SPATIAL; + } #endif // CUDNN_VERSION_MIN(7, 0, 1) #ifdef PADDLE_WITH_HIP