diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
index 37cb4fd46d2ce3bafd8040d7534f3890e8196932..149148a4a497abd7b1b092fdf0211468c1853f1f 100644
--- a/imperative/python/test/integration/test_trace_dump.py
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -47,13 +47,17 @@ class XORNet(M.Module):
         self.num_class = 2
         super().__init__()
         self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.bn0 = M.BatchNorm1d(self.mid_dim)
         self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.bn1 = M.BatchNorm1d(self.mid_dim)
         self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
 
     def forward(self, x):
         x = self.fc0(x)
+        x = self.bn0(x)
         x = F.tanh(x)
         x = self.fc1(x)
+        x = self.bn1(x)
         x = F.tanh(x)
         x = self.fc2(x)
         return x
diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp
index 52a9c774ab79e802b3f611c99a95048db3eea5cc..5d403fbf4ffadc73837e7edd98bb7ac343d923d9 100644
--- a/src/opr/impl/dnn/batch_norm.cpp
+++ b/src/opr/impl/dnn/batch_norm.cpp
@@ -44,7 +44,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
         m_force_inplace = false;
     }
 
-    if (m_force_inplace) {
+    if (m_force_inplace && param.fwd_mode == Param::FwdMode::TRAINING) {
         auto check_dest = [&](VarNode* dest) {
             auto dest_opr = dest->owner_opr();
             mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() ||
@@ -62,7 +62,14 @@
 
     add_input({x, scale, bias, mean, variance});
 
-    if (m_force_inplace) {
+    if (param.fwd_mode == Param::FwdMode::INFERENCE) {
+        auto mark_empty_var = [&](VarNode *var) {
+            var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+               .add_flag(VarNode::Flag::VOLATILE_CONTENT);
+        };
+        mark_empty_var(output(0));
+        mark_empty_var(output(1));
+    } else if (m_force_inplace) {
         output(0)->
             set_fwd_in2out_writable_force(input(3)).
             add_flag(VarNode::Flag::NO_MEM_RECLAIM);
@@ -129,7 +136,7 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
 cg::OperatorNodeBase::NodeProp*
 BatchNormForward::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
-    if (input().size() == 5) {
+    if (need_stats()) {
         ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
     }
     return ret;
@@ -140,7 +147,7 @@ void BatchNormForward::scn_do_execute() {
     auto &&y = output(4)->dev_tensor();
     mgb_assert(x.layout().is_contiguous() &&
                y.layout().is_contiguous());
-    if (input().size() == 5) { // need running mean/variance
+    if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
              &&i0 = input(3)->dev_tensor(),
@@ -164,8 +171,14 @@
     }
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
-    auto mean = output(0)->dev_tensor().as_megdnn();
-    auto variance = output(1)->dev_tensor().as_megdnn();
+    megdnn::TensorND mean, variance;
+    if (param().fwd_mode == Param::FwdMode::INFERENCE) {
+        mean = input(3)->dev_tensor().as_megdnn();
+        variance = input(4)->dev_tensor().as_megdnn();
+    } else {
+        mean = output(0)->dev_tensor().as_megdnn();
+        variance = output(1)->dev_tensor().as_megdnn();
+    }
     auto save_mean = output(2)->dev_tensor().as_megdnn();
     auto save_variance = output(3)->dev_tensor().as_megdnn();
     auto workspace = intl::get_megdnn_workspace_from_var(output().back());
@@ -180,12 +193,11 @@ void BatchNormForward::add_input_layout_constraint() {
 
 void BatchNormForward::get_output_var_shape(
         const TensorShapeArray &inp_shape, TensorShapeArray &out_shape) const {
-    size_t nr_inp = input().size();
     out_shape[4] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
     }
-    if (nr_inp == 3) {
+    if (!need_stats()) {
         out_shape[0] = out_shape[1] = {0};
     }
 }
@@ -221,7 +233,7 @@ void BatchNormForward::init_output_dtype() {
 }
 
 void BatchNormForward::mem_plan_fwd_in2out_writable() {
-    if (!m_force_inplace && input().size() == 5) {
+    if (need_stats() && !m_force_inplace) {
         // TODO: testing
         output(0)->set_fwd_in2out_writable(input(3));
         output(1)->set_fwd_in2out_writable(input(4));
diff --git a/src/opr/include/megbrain/opr/dnn/batch_norm.h b/src/opr/include/megbrain/opr/dnn/batch_norm.h
index 558f9132aae6d2a6f39fb9eda935cb60a0972c9e..c3ca2f2416a70ec773f50c6e177efc8316fde489 100644
--- a/src/opr/include/megbrain/opr/dnn/batch_norm.h
+++ b/src/opr/include/megbrain/opr/dnn/batch_norm.h
@@ -79,6 +79,8 @@ MGB_DEFINE_OPR_CLASS(BatchNormForward,
 
     // if set to True, running mean/variance will be updated inplace
     bool m_force_inplace = true;
+    // need running mean/variance
+    bool need_stats() const { return input().size() == 5 && param().fwd_mode == Param::FwdMode::TRAINING; }
 };
 
 using BatchNorm = BatchNormForward;
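
Note on the behavioral change (not part of the patch): with FwdMode::INFERENCE, BatchNormForward now reads mean/variance from its inputs and marks the corresponding outputs as empty, volatile placeholders, so running statistics are consumed rather than rewritten. A minimal Python sketch of the invariant this enables, assuming MegEngine's public module API (M.BatchNorm1d, Module.eval(), and the running_mean buffer):

import numpy as np
import megengine as mge
import megengine.module as M

bn = M.BatchNorm1d(4)
bn.eval()  # forward now runs the BatchNorm opr with FwdMode::INFERENCE

before = bn.running_mean.numpy().copy()
_ = bn(mge.tensor(np.random.randn(8, 4).astype("float32")))

# Inference mode reads the running statistics but must not update them.
assert np.allclose(before, bn.running_mean.numpy())

Flagging output(0)/output(1) as ALLOW_EMPTY_SHAPE and VOLATILE_CONTENT, instead of force-forwarding input(3)/input(4) in place, also confines the SharedDeviceTensor requirement on mean/variance to the training path, which is the case the new BatchNorm1d layers in test_trace_dump.py exercise.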