From dcd87fd68971d85b7faa90c2dfee6b6534ec2315 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 12 Sep 2016 09:58:35 +0800 Subject: [PATCH] fix CUDNN_VERSION for backward of CudnnBatchNormLayer (#61) --- paddle/gserver/layers/CudnnBatchNormLayer.cpp | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index e1762e8d360..cef8772fc25 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { } else { create(tmpBiasGrad_, 1, channels_, &betaGrad); } -#if CUDNN_VERSION < 5000 + // because of the different api of cudnn v4 and v5. - if (weight_->getWGrad()) { - create(tmpWGrad_, 1, channels_, &gammaGrad); - } - if (biases_ && biases_->getWGrad()) { - create(tmpBiasGrad_, 1, channels_, &betaGrad); + if (hl_get_cudnn_lib_version() < 5000) { + if (weight_->getWGrad()) { + create(tmpWGrad_, 1, channels_, &gammaGrad); + } + if (biases_ && biases_->getWGrad()) { + create(tmpBiasGrad_, 1, channels_, &betaGrad); + } } -#endif + hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad, ioDesc_, inGrad, bnParamDesc_, gamma, gammaGrad, betaGrad, EPS, savedMean, savedInvVar); -#if CUDNN_VERSION < 5000 // because of the different api of cudnn v4 and v5. - if (weight_->getWGrad() && biases_->getWGrad()) { - weight_->getWGrad()->add(*tmpWGrad_); - biases_->getWGrad()->add(*tmpBiasGrad_); + if (hl_get_cudnn_lib_version() < 5000) { + if (weight_->getWGrad() && biases_->getWGrad()) { + weight_->getWGrad()->add(*tmpWGrad_); + biases_->getWGrad()->add(*tmpBiasGrad_); + } } -#endif + { REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); biases_->getParameterPtr()->incUpdate(callback); -- GitLab