diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index 4cd502d25f64562f32c197d7d49ce8f15e11ec0b..989755303d602ec65eaaa36b5a4d9ef58ba7380c 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -137,7 +137,7 @@ SymbolVarArray BatchNormForward::make(SymbolVar x, cg::OperatorNodeBase::NodeProp* BatchNormForward::do_make_node_prop() const { auto ret = Super::do_make_node_prop(); - if (need_stats()) { + if (need_stats() && m_force_inplace) { ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); } return ret;