diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 4f22d28a450c1a4e72c9580f719442b1a8a0f81b..3467658e894d547d7a0106d6414624b10351d6f3 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -848,7 +848,8 @@ void BatchNormGradMaker::Apply(GradOpPtr op) const { } // used when setting use_global_stats True during training - if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) { + if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats")) || + BOOST_GET_CONST(bool, this->GetAttr("is_test"))) { op->SetInput("Mean", this->Output("MeanOut")); op->SetInput("Variance", this->Output("VarianceOut")); }