提交 9a9e0597 编写于 作者: L liaogang

Merge remote-tracking branch 'upstream/master'

...@@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) { ...@@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
} else { } else {
create(tmpBiasGrad_, 1, channels_, &betaGrad); create(tmpBiasGrad_, 1, channels_, &betaGrad);
} }
#if CUDNN_VERSION < 5000
// because of the different api of cudnn v4 and v5. // because of the different api of cudnn v4 and v5.
if (weight_->getWGrad()) { if (hl_get_cudnn_lib_version() < 5000) {
create(tmpWGrad_, 1, channels_, &gammaGrad); if (weight_->getWGrad()) {
} create(tmpWGrad_, 1, channels_, &gammaGrad);
if (biases_ && biases_->getWGrad()) { }
create(tmpBiasGrad_, 1, channels_, &betaGrad); if (biases_ && biases_->getWGrad()) {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
}
} }
#endif
hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad, hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad,
ioDesc_, inGrad, bnParamDesc_, ioDesc_, inGrad, bnParamDesc_,
gamma, gammaGrad, betaGrad, gamma, gammaGrad, betaGrad,
EPS, savedMean, savedInvVar); EPS, savedMean, savedInvVar);
#if CUDNN_VERSION < 5000
// because of the different api of cudnn v4 and v5. // because of the different api of cudnn v4 and v5.
if (weight_->getWGrad() && biases_->getWGrad()) { if (hl_get_cudnn_lib_version() < 5000) {
weight_->getWGrad()->add(*tmpWGrad_); if (weight_->getWGrad() && biases_->getWGrad()) {
biases_->getWGrad()->add(*tmpBiasGrad_); weight_->getWGrad()->add(*tmpWGrad_);
biases_->getWGrad()->add(*tmpBiasGrad_);
}
} }
#endif
{ {
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
biases_->getParameterPtr()->incUpdate(callback); biases_->getParameterPtr()->incUpdate(callback);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册