未验证 提交 3629bf4f 编写于 作者: Li Min 提交者: GitHub

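Replace CUDNN_BATCHNORM_SPATIAL with CUDNN_BATCHNORM_PER_ACTIVATION mode in the batch norm op (forward and backward kernels) when H == 1 and W == 1, unless the spatial-persistent flag is set, to improve performance (#33887)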

上级 eae31856
...@@ -225,11 +225,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -225,11 +225,17 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
#elif CUDNN_VERSION_MIN(7, 0, 1) #elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) { if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
} else if (H == 1 && W == 1) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else { } else {
mode_ = CUDNN_BATCHNORM_SPATIAL; mode_ = CUDNN_BATCHNORM_SPATIAL;
} }
#else #else
mode_ = CUDNN_BATCHNORM_SPATIAL; if (H == 1 && W == 1) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode_ = CUDNN_BATCHNORM_SPATIAL;
}
#endif // CUDNN_VERSION_MIN(7, 0, 1) #endif // CUDNN_VERSION_MIN(7, 0, 1)
VLOG(3) << "Setting descriptors."; VLOG(3) << "Setting descriptors.";
...@@ -989,11 +995,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T> ...@@ -989,11 +995,17 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
#elif CUDNN_VERSION_MIN(7, 0, 1) #elif CUDNN_VERSION_MIN(7, 0, 1)
if (FLAGS_cudnn_batchnorm_spatial_persistent) { if (FLAGS_cudnn_batchnorm_spatial_persistent) {
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
} else if (H == 1 && W == 1) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else { } else {
mode_ = CUDNN_BATCHNORM_SPATIAL; mode_ = CUDNN_BATCHNORM_SPATIAL;
} }
#else #else
mode_ = CUDNN_BATCHNORM_SPATIAL; if (H == 1 && W == 1) {
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode_ = CUDNN_BATCHNORM_SPATIAL;
}
#endif // CUDNN_VERSION_MIN(7, 0, 1) #endif // CUDNN_VERSION_MIN(7, 0, 1)
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册