Commit 7dbc092c authored by qingqing01, committed by emailweixu

fix cudnn version number for batch norm. (#71)

* fix CUDNN_VERSION for backward of CudnnBatchNormLayer

* fix cudnn version number for BatchNorm
Parent d6cc5203
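Why 4007 rather than 4000: cudnn.h defines CUDNN_VERSION as CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL, so 4007 is cuDNN 4.0.7, the production v4 release, while >= 4000 would also admit earlier pre-release v4 builds (presumably with a different batch norm API, hence this fix). A minimal standalone sketch of the gate; the macro values are hard-coded here purely for illustration, real code takes them from cudnn.h:

    // Standalone sketch of the version gate, assuming the cudnn.h convention
    // CUDNN_VERSION = CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL.
    #include <cstdio>

    #define CUDNN_MAJOR 4       // hard-coded for illustration only
    #define CUDNN_MINOR 0
    #define CUDNN_PATCHLEVEL 7
    #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)

    int main() {
    #if CUDNN_VERSION >= 4007
      // 4007 == cuDNN 4.0.7; 5000 == cuDNN 5.0.0, and so on.
      std::printf("cuDNN %d.%d.%d (CUDNN_VERSION=%d): batch norm APIs compiled in\n",
                  CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL, CUDNN_VERSION);
    #else
      std::printf("CUDNN_VERSION=%d: batch norm disabled at compile time\n",
                  CUDNN_VERSION);
    #endif
      return 0;
    }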
@@ -150,7 +150,7 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
 // APIs available after R4:
-#if CUDNN_VERSION >= 4000
+#if CUDNN_VERSION >= 4007
 #define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro)    \
   __macro(cudnnBatchNormalizationForwardTraining)   \
   __macro(cudnnBatchNormalizationForwardInference)  \
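The compile-time gate above only decides which wrappers get generated; the routines themselves are resolved from libcudnn at run time (that is what DYNAMIC_LOAD_CUDNN_WRAP expands to). A rough standalone sketch of that pattern, assuming plain dlopen/dlsym on Linux rather than Paddle's actual wrapper machinery (build with -ldl):

    // Probe libcudnn at run time for a batch norm entry point, so a binary
    // can still load and fail gracefully when the installed cuDNN is too old.
    #include <dlfcn.h>
    #include <cstdio>

    int main() {
      void *lib = dlopen("libcudnn.so", RTLD_LAZY);
      if (lib == NULL) {
        std::printf("libcudnn.so not found: %s\n", dlerror());
        return 1;
      }
      // Present from cuDNN v4 on (4.0.7 per this commit's gate).
      void *fn = dlsym(lib, "cudnnBatchNormalizationForwardTraining");
      std::printf("cudnnBatchNormalizationForwardTraining: %s\n",
                  fn != NULL ? "found" : "missing (cuDNN too old)");
      dlclose(lib);
      return 0;
    }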
@@ -999,7 +999,7 @@ void hl_batch_norm_forward_training(hl_tensor_descriptor inputDesc,
                                     double epsilon,
                                     real *savedMean,
                                     real *savedVar) {
-#if CUDNN_VERSION >= 4000
+#if CUDNN_VERSION >= 4007
   if ((NULL != runningMean && NULL == runningInvVar) ||
       (NULL == runningMean && NULL != runningInvVar)) {
     LOG(FATAL) << "runningMean and runningInvVar can be NULL "
@@ -1024,7 +1024,7 @@ void hl_batch_norm_forward_training(hl_tensor_descriptor inputDesc,
   CHECK_SYNC("hl_batch_norm_forward_training failed");
 #else
-  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4000. "
+  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
              << "But cudnn lib version is " << g_cudnn_lib_version;
 #endif
 }
@@ -1039,7 +1039,7 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
                                      real *estimatedMean,
                                      real *estimatedInvVar,
                                      double epsilon) {
-#if CUDNN_VERSION >= 4000
+#if CUDNN_VERSION >= 4007
   cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
   cudnnTensorDescriptor_t yDesc = GET_TENSOR_DESCRIPTOR(outputDesc);
   cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(bnParamDesc);
@@ -1053,7 +1053,7 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
   CHECK_SYNC("hl_batch_norm_forward_inference failed");
 #else
-  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4000. "
+  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
              << "But cudnn lib version is " << g_cudnn_lib_version;
 #endif
 }
@@ -1071,7 +1071,7 @@ void hl_batch_norm_backward(hl_tensor_descriptor inputDesc,
                             double epsilon,
                             real *savedMean,
                             real *savedInvVar) {
-#if CUDNN_VERSION >= 4000
+#if CUDNN_VERSION >= 4007
   if ((NULL != savedMean && NULL == savedInvVar) ||
       (NULL == savedMean && NULL != savedInvVar)) {
     LOG(FATAL) << "savedMean and savedVar can be NULL "
@@ -1087,16 +1087,14 @@ void hl_batch_norm_backward(hl_tensor_descriptor inputDesc,
   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
   CHECK_CUDNN(dynload::cudnnBatchNormalizationBackward(
       t_resource.cudnn_handle, mode, &alpha, &beta,
-#if CUDNN_VERSION >= 5000
       &alpha, &beta,
-#endif
       xDesc, input, dyDesc, outGrad, dxDesc, inGrad,
       bnDesc, scale, scaleGrad, biasGrad, epsilon,
       savedMean, savedInvVar));
   CHECK_SYNC("hl_batch_norm_backward failed");
 #else
-  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4000. "
+  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
              << "But cudnn lib version is " << g_cudnn_lib_version;
 #endif
 }
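The deleted #if CUDNN_VERSION >= 5000 guard dates from when older headers declared cudnnBatchNormalizationBackward with a single alpha/beta pair; newer signatures take two pairs, one blending the data gradient and one blending the scale/bias gradients. With the minimum raised to 4.0.7, the second pair can be passed unconditionally. Each pair follows cuDNN's usual blending rule, dst = alpha * result + beta * dst, which is also why the manual temp-buffer accumulation removed in the next hunk is no longer needed. A toy sketch of that rule (plain C++, not cuDNN code):

    #include <cstdio>

    // cuDNN's blending convention for any output guarded by an alpha/beta
    // pair: dst = alpha * result + beta * dst. beta == 0 overwrites, beta == 1
    // accumulates -- the in-place equivalent of the tmpWGrad_/tmpBiasGrad_
    // add() workaround deleted in the CudnnBatchNormLayer hunk below.
    static void blend(float alpha, float beta, const float *result, float *dst,
                      int n) {
      for (int i = 0; i < n; ++i) {
        dst[i] = alpha * result[i] + beta * dst[i];
      }
    }

    int main() {
      float result[2] = {0.5f, -1.0f};  // freshly computed parameter gradients
      float grad[2] = {2.0f, 3.0f};     // gradients accumulated so far
      blend(1.0f, 1.0f, result, grad, 2);
      std::printf("%.1f %.1f\n", grad[0], grad[1]);  // prints: 2.5 2.0
      return 0;
    }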
@@ -115,29 +115,11 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
     create(tmpBiasGrad_, 1, channels_, &betaGrad);
   }
-  // because of the different api of cudnn v4 and v5.
-  if (hl_get_cudnn_lib_version() < 5000) {
-    if (weight_->getWGrad()) {
-      create(tmpWGrad_, 1, channels_, &gammaGrad);
-    }
-    if (biases_ && biases_->getWGrad()) {
-      create(tmpBiasGrad_, 1, channels_, &betaGrad);
-    }
-  }

   hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad,
                          ioDesc_, inGrad, bnParamDesc_,
                          gamma, gammaGrad, betaGrad,
                          EPS, savedMean, savedInvVar);
-  // because of the different api of cudnn v4 and v5.
-  if (hl_get_cudnn_lib_version() < 5000) {
-    if (weight_->getWGrad() && biases_->getWGrad()) {
-      weight_->getWGrad()->add(*tmpWGrad_);
-      biases_->getWGrad()->add(*tmpBiasGrad_);
-    }
-  }

   {
     REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
     biases_->getParameterPtr()->incUpdate(callback);
......
@@ -1614,7 +1614,7 @@ class BatchNormLayer(LayerBase):
         # Also based on cudnn version.
         use_cudnn = use_gpu and batch_norm_type != "batch_norm" and \
             ((not parallel_nn) or self.config.device > -1) and \
-            cudnn_version >= 4000
+            cudnn_version >= 4007
         self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm"
         super(BatchNormLayer, self).__init__(name, self.layer_type, 0,
                                              active_type=active_type,
......