提交 dcd87fd6 编写于 作者: Q qingqing01 提交者: Yu Yang

fix CUDNN_VERSION for backward of CudnnBatchNormLayer (#61)

上级 674d69ce
......@@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
} else {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
}
#if CUDNN_VERSION < 5000
// because of the different api of cudnn v4 and v5.
if (weight_->getWGrad()) {
create(tmpWGrad_, 1, channels_, &gammaGrad);
}
if (biases_ && biases_->getWGrad()) {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
if (hl_get_cudnn_lib_version() < 5000) {
if (weight_->getWGrad()) {
create(tmpWGrad_, 1, channels_, &gammaGrad);
}
if (biases_ && biases_->getWGrad()) {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
}
}
#endif
hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad,
ioDesc_, inGrad, bnParamDesc_,
gamma, gammaGrad, betaGrad,
EPS, savedMean, savedInvVar);
#if CUDNN_VERSION < 5000
// because of the different api of cudnn v4 and v5.
if (weight_->getWGrad() && biases_->getWGrad()) {
weight_->getWGrad()->add(*tmpWGrad_);
biases_->getWGrad()->add(*tmpBiasGrad_);
if (hl_get_cudnn_lib_version() < 5000) {
if (weight_->getWGrad() && biases_->getWGrad()) {
weight_->getWGrad()->add(*tmpWGrad_);
biases_->getWGrad()->add(*tmpBiasGrad_);
}
}
#endif
{
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
biases_->getParameterPtr()->incUpdate(callback);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册