提交 d5c6eecb 编写于 作者: L Li Xinqi 提交者: Jinhui Yuan

call cudnnBatchNormalizationForwardInference if trainable == flase (#1197)



Former-commit-id: a21dea46
上级 5fa70913
......@@ -92,7 +92,7 @@ void NormalizationKernel<DeviceType::kGPU, float>::NormalizationCudnnForward(
float* moving_mean = BnInOp2Blob("moving_mean")->mut_dptr<float>();
float* moving_variance = BnInOp2Blob("moving_variance")->mut_dptr<float>();
double epsilon = this->op_conf().normalization_conf().epsilon();
if (Global<JobDesc>::Get()->IsTrain()) {
if (this->op_conf().trainable()) {
InitMovingMeanAndMovingVariance(ctx, BnInOp2Blob, false);
double momentum = this->op_conf().normalization_conf().momentum();
CudaCheck(cudnnBatchNormalizationForwardTraining(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册