From abcb1e10237e07d88a25bef8ddbf7e8d3632367f Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 2 Aug 2017 23:50:59 +0800 Subject: [PATCH] add the check of cudnn version in cudnnBatchNorm --- paddle/cuda/src/hl_cuda_cudnn.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index c53a56368..7ad8a3976 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, real alpha = 1.0f; real beta = 1.0f; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; + + int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size; + if (batch_size > 1024 && g_cudnn_lib_version < 6000) { + LOG(INFO) << " To process current batch data with size " << batch_size + << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000." + << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED," + << " just recompile PaddlePaddle with cuDNN >= 6000, replacing" + << " current version " << g_cudnn_lib_version; + } CHECK_CUDNN( dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, mode, -- GitLab