未验证 提交 4fe1bb45 编写于 作者: C ceci3 提交者: GitHub

fix mean/variance when is_test=True (#35328)

上级 668bfb35
...@@ -848,7 +848,8 @@ void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const { ...@@ -848,7 +848,8 @@ void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const {
} }
// used when setting use_global_stats True during training // used when setting use_global_stats True during training
if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) { if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats")) ||
BOOST_GET_CONST(bool, this->GetAttr("is_test"))) {
op->SetInput("Mean", this->Output("MeanOut")); op->SetInput("Mean", this->Output("MeanOut"));
op->SetInput("Variance", this->Output("VarianceOut")); op->SetInput("Variance", this->Output("VarianceOut"));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册