diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index 64315ab0aabd0f0b3e0352ec5c14a2a4bad1acda..6f5e9e81bce3c0feb8210f11774a48d1d422a7df 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -74,7 +74,7 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
 
             auto &&ins = m_varmap.insert({out0[i], {true, nullptr}});
             mgb_assert(ins.second || ins.first->second.first,
-                    "opr output already replaced");
+                    "opr output already replaced"); // handle repeated call on the same opr
             ins.first->second.second = out1[i];
             on_var_replaced(out0[i], out1[i], nullptr);
 
@@ -771,7 +771,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
 
 /* ================ ConstVarPropogateBase ================ */
 
-ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
+ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
         OperatorNodeBase *opr) {
     using ProfFlag = OperatorNodeBase::NodeProp::Flag;
     auto &&info = m_oprinfo[opr];
@@ -834,7 +834,6 @@ ConstVarPropogateBase::AddOprResult ConstVarPropogateBase::add_opr(
 #endif
         info.max_size = max_input_size;
         info.is_const = true;
-        on_midconst_opr(opr, max_input_size);
     }
     return make_ret();
 }
diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index b569b2fd4cd51bc6154609bbd90a18ba725b6981..c60f876b11e3d822ca1ef31c70bf8620d4ab158e 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -442,50 +442,6 @@ void ParamRedistributePass::apply(OptState &state) const {
 
 /* ================ ParamFusePass ================ */
 
-class ParamFusePass::ConstVarPropogateWithSizeCheck final:
-        public ConstVarPropogateBase
-{
-    public:
-        //! rewrite a var; reader == nullptr means needed by endpoint
-        using VarRewriter = std::function<
-            void(VarNode *var, OperatorNodeBase *reader)>;
-
-        ConstVarPropogateWithSizeCheck(
-                const ParamFusePass &pf, OptState &opt_state,
-                const VarRewriter &rewriter):
-            ConstVarPropogateBase{ConstVarType::IMMUTABLE_AND_PARAM},
-            m_owner{pf}, m_opt_state{opt_state}, m_rewriter{rewriter}
-        {
-        }
-
-    private:
-
-        const ParamFusePass &m_owner;
-        OptState &m_opt_state;
-        VarRewriter m_rewriter;
-
-        void on_midconst_opr(
-                OperatorNodeBase *opr, size_t max_src_size) override {
-            for (auto var: opr->output()) {
-                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
-                    continue;
-
-                auto osize = var_mem_size(var);
-                if (osize >= max_src_size &&
-                        osize - max_src_size > m_owner.m_param_grow_limit) {
-                    return;
-                }
-
-                // const oprs should be evaluated when output is used by another
-                // non-const opr or output is needed by the user
-                if (m_opt_state.graph().endpoint_contain(var)) {
-                    m_rewriter(var, nullptr);
-                }
-
-            }
-        }
-};
-
 /*!
  * \brief get name for new param
  */
@@ -565,9 +521,15 @@ const char* ParamFusePass::name() const {
 
 void ParamFusePass::apply(OptState &state) const {
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();
+
+    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+    state.graph().iter([&cvprop](OperatorNodeBase *opr) {
+        cvprop.add_opr(opr);
+    });
+
+    ThinHashSet<VarNode*> processed_var;
     VarNamer var_namer;
 
-    // reader: null if used as endvar
     auto replace_single_var = [&](VarNode *var, OperatorNodeBase *reader) {
         if (!processed_var.insert(var).second)
@@ -619,9 +581,8 @@ void ParamFusePass::apply(OptState &state) const {
         rewriter.replace_var(var, new_var.node(), log.c_str());
     };
 
-    ConstVarPropogateWithSizeCheck cvprop{*this, state, replace_single_var};
-    auto on_opr = [&](OperatorNodeBase *opr) {
-        auto add_ret = cvprop.add_opr(opr);
+    auto replace_opr = [&](OperatorNodeBase* opr) {
+        auto add_ret = cvprop.opr_rst(opr);
         if (!add_ret.all_const_inp && add_ret.has_midconst_inp) {
             for (auto i: opr->input()) {
                 if (cvprop.is_midconst(i)) {
@@ -631,9 +592,33 @@ void ParamFusePass::apply(OptState &state) const {
             }
         }
         rewriter.auto_replace_outputs(opr);
+
+        //! deal with midconst vars only after auto_replace_outputs(): if an
+        //! endpoint output were replaced first (as the old on_midconst_opr
+        //! callback did), it would get replaced twice
+        if (add_ret.all_const_inp) {
+            for (auto var : opr->output()) {
+                if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT))
+                    continue;
+
+                auto osize = ConstVarPropogate::var_mem_size(var);
+                if (osize >= cvprop.max_size(opr) &&
+                        osize - cvprop.max_size(opr) > m_param_grow_limit) {
+                    return;
+                }
+
+                // const oprs should be evaluated when output is used by another
+                // non-const opr or output is needed by the user
+                if (state.graph().endpoint_contain(var)) {
+                    replace_single_var(var, nullptr);
+                }
+
+            }
+
+        }
     };
-    state.graph().iter(on_opr);
+    state.graph().iter(replace_opr);
 
     rewriter.apply_inplace();
 }
 
diff --git a/src/gopt/include/megbrain/gopt/framework.h b/src/gopt/include/megbrain/gopt/framework.h
index 56d79f4c62ab6e70059d30e6bc3cdad35d4d58d0..091a2fe975fb7f77c5341855bdd1311a859e0f38 100644
--- a/src/gopt/include/megbrain/gopt/framework.h
+++ b/src/gopt/include/megbrain/gopt/framework.h
@@ -490,28 +490,17 @@ namespace gopt {
 
      * Usually you would want to use ConstVarPropogate, and this base class
      * exists to avoid virtual dtor while allowing polymorphism.
      */
-    class ConstVarPropogateBase {
-        protected:
-            ~ConstVarPropogateBase() = default;
-
-            //! memory usage of a var
-            static size_t var_mem_size(VarNode *var) {
-                return var->dtype().size(var->shape().total_nr_elems());
-            }
-
-            //! called after a const but non-source opr is visited
-            virtual void on_midconst_opr(
-                    OperatorNodeBase *opr, size_t max_src_size) {
-                MGB_MARK_USED_VAR(opr);
-                MGB_MARK_USED_VAR(max_src_size);
-            }
+    class ConstVarPropogate {
         public:
-            explicit ConstVarPropogateBase(ConstVarType const_var_type):
+            explicit ConstVarPropogate(ConstVarType const_var_type):
                 m_const_var_type{const_var_type}
            {
            }
 
+            ConstVarPropogate() = default;
+            ~ConstVarPropogate() = default;
+
             //! note that both attrs would be false if opr is impure or it is
             //! not allowed to be replaced
             struct AddOprResult {
@@ -527,12 +516,19 @@
 
             AddOprResult add_opr(OperatorNodeBase *opr);
 
+            const AddOprResult& opr_rst(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).result;
+            }
+
             bool is_const(OperatorNodeBase *opr) const {
                 return m_oprinfo.at(opr).is_const;
             }
             bool is_const(VarNode *var) const {
                 return is_const(var->owner_opr());
             }
+            size_t max_size(OperatorNodeBase *opr) const {
+                return m_oprinfo.at(opr).max_size;
+            }
 
             //! whether a var is produced by non-source const opr
             bool is_midconst(OperatorNodeBase *opr) const {
@@ -543,6 +539,11 @@
                 return is_midconst(var->owner_opr());
             }
 
+            //! memory usage of a var
+            static size_t var_mem_size(VarNode *var) {
+                return var->dtype().size(var->shape().total_nr_elems());
+            }
+
         private:
             struct OprInfo {
                 bool processed = false, is_const = false;
@@ -556,11 +557,6 @@
 
            };
 
-    class ConstVarPropogate final: public ConstVarPropogateBase {
-        public:
-            using ConstVarPropogateBase::ConstVarPropogateBase;
-    };
-
 } // namespace gopt
 } // namespace mgb
 
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 03bf1bc1cad3389bf53435a51fedb72f68d58d8c..8321fd4de1fa5458d30875b1831c74483da1f6ba 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -112,6 +112,52 @@ void warp_perspective_mat_gen(HostTensorND& mat, size_t N, size_t INP_H,
 #endif
 } // namespace
 
+TEST(TestGoptInference, ParamFuseConstEndPoint) {
+    constexpr size_t SIZE = 23;
+    HostTensorGenerator<> gen;
+    auto host_x = gen({SIZE}), host_y = gen({1}), host_p = gen({1});
+
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
+         y = opr::SharedDeviceTensor::make(*graph, *host_y),
+         p = opr::Host2DeviceCopy::make(*graph, host_p),
+         q = p + x,
+         a = y + 3,
+         z0 = a + q,
+         z1 = a + 4;
+
+    HostTensorND host_z0, host_z1;
+
+    SymbolVar z0_1, z1_1;
+    unpack_vector(
+            gopt::GraphOptimizer{}.
+            add_pass<gopt::ParamFusePass>().
+            apply({{z1, z0}}).endpoint_vars(),
+            z1_1, z0_1);
+
+    auto func = graph->compile({make_callback_copy(z0_1, host_z0),
+                                make_callback_copy(z1_1, host_z1)});
+    func->to_json()->writeto_fpath(
+            output_file("TestGoptInference.ParamFuseEndPoint.json"));
+    func->execute();
+
+    int nr_opr = 0;
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
+    ASSERT_EQ(8, nr_opr);
+
+    auto px = host_x->ptr<float>(), pz0 = host_z0.ptr<float>();
+
+    auto yv = host_y->ptr<float>()[0], pv = host_p->ptr<float>()[0],
+         pz1 = host_z1.ptr<float>()[0];
+
+    for (size_t i = 0; i < SIZE; ++ i) {
+        MGB_ASSERT_FLOAT_EQ(px[i] + yv + 3 + pv, pz0[i]);
+    }
+    MGB_ASSERT_FLOAT_EQ(yv + 7, pz1);
+}
+
+
 TEST(TestGoptInference, ParamFuse) {
     constexpr size_t SIZE = 23;
     HostTensorGenerator<> gen;
@@ -144,7 +190,7 @@ TEST(TestGoptInference, ParamFuse) {
     func->execute();
 
     int nr_opr = 0;
-    func->iter_opr_seq([&](cg::OperatorNodeBase*op) {++ nr_opr; return true; });
+    func->iter_opr_seq([&](cg::OperatorNodeBase*) {++ nr_opr; return true; });
     ASSERT_EQ(6, nr_opr);
 
     auto px = host_x->ptr<float>(), pz = host_z.ptr<float>(),
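
Review note: the patch replaces the `on_midconst_opr` virtual-callback design with a plain `ConstVarPropogate` value type that is filled in one graph walk and queried in a second one. Below is a minimal sketch of that two-phase pattern, using only the API visible in this diff (`add_opr`, `opr_rst`, `max_size`, `var_mem_size`, `OptState::graph().iter`, `endpoint_contain`); `handle_endpoint_var` and `two_phase_pass` are hypothetical names standing in for a pass-specific rewrite such as `replace_single_var`, and `param_grow_limit` mirrors `m_param_grow_limit`:

```cpp
// Sketch only, assuming the megbrain headers touched by this patch.
#include "megbrain/gopt/framework.h"

using namespace mgb;
using namespace gopt;

// hypothetical rewrite hook; a real pass would call something like
// replace_single_var(var, nullptr) here
void handle_endpoint_var(VarNode* /*var*/) {}

void two_phase_pass(OptState& state, size_t param_grow_limit) {
    // Phase 1: analyze const-ness of every opr once, up front.
    ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
    state.graph().iter([&cvprop](OperatorNodeBase* opr) {
        cvprop.add_opr(opr);
    });

    // Phase 2: rewrite; opr_rst() returns the cached AddOprResult, so the
    // rewrite order is decoupled from the analysis. The fix above relies on
    // this: a const endpoint is replaced only after auto_replace_outputs().
    state.graph().iter([&](OperatorNodeBase* opr) {
        if (!cvprop.opr_rst(opr).all_const_inp)
            return;
        for (auto var : opr->output()) {
            auto osize = ConstVarPropogate::var_mem_size(var);
            // same size guard as ParamFusePass: do not let fusing grow a
            // replaced param by more than param_grow_limit bytes
            if (osize >= cvprop.max_size(opr) &&
                    osize - cvprop.max_size(opr) > param_grow_limit)
                return;  // as in the pass, skip the remaining outputs too
            if (state.graph().endpoint_contain(var))
                handle_endpoint_var(var);
        }
    });
}
```

Because the analysis is computed before any rewriting starts, `replace_opr` can consult `opr_rst` after `rewriter.auto_replace_outputs(opr)` has run, so an endpoint that is itself const is replaced exactly once; that is the scenario exercised by the new `TestGoptInference.ParamFuseConstEndPoint` test.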