diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp index e1762e8d360dea6af5763966dc68b479b54df5df..cef8772fc254f98d676e6fb89042487315280c61 100644 --- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp +++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp @@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { } else { create(tmpBiasGrad_, 1, channels_, &betaGrad); } -#if CUDNN_VERSION < 5000 + // because of the different api of cudnn v4 and v5. - if (weight_->getWGrad()) { - create(tmpWGrad_, 1, channels_, &gammaGrad); - } - if (biases_ && biases_->getWGrad()) { - create(tmpBiasGrad_, 1, channels_, &betaGrad); + if (hl_get_cudnn_lib_version() < 5000) { + if (weight_->getWGrad()) { + create(tmpWGrad_, 1, channels_, &gammaGrad); + } + if (biases_ && biases_->getWGrad()) { + create(tmpBiasGrad_, 1, channels_, &betaGrad); + } } -#endif + hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad, ioDesc_, inGrad, bnParamDesc_, gamma, gammaGrad, betaGrad, EPS, savedMean, savedInvVar); -#if CUDNN_VERSION < 5000 // because of the different api of cudnn v4 and v5. - if (weight_->getWGrad() && biases_->getWGrad()) { - weight_->getWGrad()->add(*tmpWGrad_); - biases_->getWGrad()->add(*tmpBiasGrad_); + if (hl_get_cudnn_lib_version() < 5000) { + if (weight_->getWGrad() && biases_->getWGrad()) { + weight_->getWGrad()->add(*tmpWGrad_); + biases_->getWGrad()->add(*tmpBiasGrad_); + } } -#endif + { REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); biases_->getParameterPtr()->incUpdate(callback);