From 76f4f9753667e7c6ad5aac8553739c16489c01f6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 May 2021 16:52:20 +0800 Subject: [PATCH] refactor(sublinear): add SeqModifierBase GitOrigin-RevId: 2d0393be6b950690c5960ac63bd47931a3afb324 --- src/core/impl/graph/seq_modifier_base.cpp | 161 ++++++++++ src/core/impl/graph/seq_modifier_base.h | 237 +++++++++++++++ src/core/impl/graph/seq_sublinear_memory.cpp | 290 ++----------------- src/core/impl/graph/seq_sublinear_memory.h | 73 +---- 4 files changed, 438 insertions(+), 323 deletions(-) create mode 100644 src/core/impl/graph/seq_modifier_base.cpp create mode 100644 src/core/impl/graph/seq_modifier_base.h diff --git a/src/core/impl/graph/seq_modifier_base.cpp b/src/core/impl/graph/seq_modifier_base.cpp new file mode 100644 index 000000000..0eaa9d328 --- /dev/null +++ b/src/core/impl/graph/seq_modifier_base.cpp @@ -0,0 +1,161 @@ +/** + * \file src/core/impl/graph/seq_modifier_base.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./seq_modifier_base.h" + +#if MGB_ENABLE_SUBLINEAR + +using namespace mgb; +using namespace cg; + +void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq) { + m_orig_opr_seq = &opr_seq; + + m_var_storage.clear(); + m_seq.clear(); + m_var_mempool.reorder_free(); + m_opr_mempool.reorder_free(); + m_nr_endpoint_oprs = 0; + + ThinHashMap varmap; + for (auto orig_opr : *m_orig_opr_seq) { + auto time = m_seq.size(); + m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); + auto opr = m_seq.back().get(); + m_nr_endpoint_oprs += opr->is_endpoint; + for (auto&& dep : orig_opr->node_prop().dep_map()) { + if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) + continue; + + auto iter = varmap.find(dep.first); + if (iter == varmap.end()) { + // input var needs not to be considered + continue; + } + + auto ivar = iter->second; + bool exist = false; + for (auto i : opr->input) { + if (i == ivar) { + exist = true; + break; + } + } + if (exist) { + // same var for different inputs + continue; + } + + opr->input.push_back(ivar); + auto&& prev_rec = ivar->access_rec.back(); + prev_rec.stride = time - prev_rec.opr->time; + ivar->access_rec.emplace_back(opr); + } + + for (auto i : orig_opr->output()) { + auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); + auto iter = var2memsize->find(i); + if (iter == var2memsize->end()) { + // some vars are ignored; see split_into_cn2oprseq() + continue; + } + m_var_storage.emplace_back( + m_var_mempool.alloc_unique(i, iter->second, opr)); + auto ovar = m_var_storage.back().get(); + varmap[i] = ovar; + opr->output.push_back(ovar); + } + mgb_assert(!opr->output.empty()); + } + + // remove unused output + for (auto&& i : m_seq) { + auto&& oarr = i->output; + for (size_t j = 0; j < oarr.size();) { + if (oarr[j]->access_rec.size() == 1) { + std::swap(oarr[j], oarr.back()); + oarr.pop_back(); + } else + ++j; + } + } +} + +bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) { + m_new_inputs.assign(inputs.begin(), inputs.end()); + bool changed = false; + for (auto&& i : m_new_inputs) { + auto iter = m_var_map.find(i); + if (iter != m_var_map.end()) { + i = iter->second; + changed = true; + } + } + return changed; +} + +OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs( + OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) { + auto config = opr->config(); + // update operator instance id to bybass the shallow copy's cache if + // it's a dup-opr-copying due to discarding. + // Don't update instance id by `this` pointer if it's a recomp-opr-copying + // because: + // 0) recomp-opr would be copied iff its input vars is changed + // 1) some pair of recomp-opr and dup-opr have the same inputs, params + // and config, we use instance id to differentiate them. + config.name(opr->name() + (recomp ? ":recomp" : ":dup") + std::to_string(recomp_cnt)); + config.update_instance_id(reinterpret_cast( + reinterpret_cast(this) + + ((static_cast(recomp) + 1) << 10) * recomp_cnt)); + + // Note: if all outputs of op were placed on the same comp_node, since its + // stream maybe changed during seq_comp_node_opt, output's comp_node has + // higher priority than opr->config() + auto out_cn = opr->output(0)->comp_node(); + for (auto i : opr->output()) { + auto cn = i->comp_node(); + if (out_cn != cn) { + out_cn = {}; + break; + } + } + if (out_cn.valid()) + config.comp_node(out_cn); + + auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); + mgb_assert(opr_new != opr); + + auto&& out0 = opr->output(); + auto&& out1 = opr_new->output(); + mgb_assert(out0.size() == out1.size()); + bool stream_changed = false; + for (size_t i = 0; i < out0.size(); ++i) { + auto &&cn0 = out0[i]->comp_node(), + &&cn1 = out1[i]->comp_node(); + if (cn0 != cn1) { + mgb_assert(recomp); + mgb_assert(cn0.locator().type == cn1.locator().type && + cn0.locator().device == cn1.locator().device); + out1[i]->comp_node(cn0); + stream_changed = true; + } + m_var_map[out0[i]] = out1[i]; + } + if (stream_changed) { + opr_new->on_output_comp_node_stream_changed(); + } + return opr_new; +} + +#endif // MGB_ENABLE_SUBLINEAR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/src/core/impl/graph/seq_modifier_base.h b/src/core/impl/graph/seq_modifier_base.h new file mode 100644 index 000000000..6ce74c805 --- /dev/null +++ b/src/core/impl/graph/seq_modifier_base.h @@ -0,0 +1,237 @@ +/** + * \file src/core/impl/graph/seq_modifier_base.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "./memory_optimizer.h" +#include "megbrain/comp_node_env.h" +#include "megbrain/graph/cg.h" +#include "megbrain/plugin/opr_footprint.h" +#include "megbrain/serialization/opr_shallow_copy.h" +#include "megbrain/system.h" +#include "megbrain/utils/async_worker.h" +#include "megbrain/utils/arith_helper.h" +#include "megbrain/utils/mempool.h" +#include "megbrain/utils/timer.h" + +#if MGB_ENABLE_SUBLINEAR +namespace mgb { +namespace cg { + +/*! + * \brief modifying computing sequence, with basically the same idea of Training + * Deep Nets with Sublinear Memory Cost and Dynamic Tensor Rematerialization + */ +class SeqModifierBase { +public: + /*! + * describes modifications that should be applied to an operator sequnce: + * maps from an opr to the oprs that should be duplicated and inserted + * before it. + */ + using SeqModifyAction = std::unordered_map; + + struct Var; + struct Opr; + + class ModifyActionPlannerBase { + const SeqModifierBase* const m_par_modifier; + const OprNodeArray* m_orig_opr_seq; + + MemPool m_var_mempool; + MemPool m_opr_mempool; + std::vector::UniquePtr> m_var_storage; + std::vector::UniquePtr> m_seq; + size_t m_nr_endpoint_oprs = 0; + + public: + //! special creation time used for oprs duplicated from others + static constexpr size_t DUPOPR_TIME = + std::numeric_limits::max() - 1; + + const SeqModifierBase* const par_modifier() { + return m_par_modifier; + } + + const OprNodeArray* const orig_opr_seq() { + return m_orig_opr_seq; + } + + MemPool& var_mempool() { + return m_var_mempool; + } + + MemPool& opr_mempool() { + return m_opr_mempool; + } + + std::vector::UniquePtr>& var_storage() { + return m_var_storage; + } + + std::vector::UniquePtr>& seq() { + return m_seq; + } + + size_t& nr_endpoint_oprs() { + return m_nr_endpoint_oprs; + } + + ModifyActionPlannerBase(SeqModifierBase* par) + : m_par_modifier{par} {} + + ~ModifyActionPlannerBase() noexcept { + m_opr_mempool.disable_freelist(); + m_var_mempool.disable_freelist(); + } + + //! init m_orig_opr_seq from opr_seq, should be called first. + void init_seq(const OprNodeArray& opr_seq); + }; + + SeqModifierBase(ComputingGraphImpl* owner) : m_mem_opt(owner), m_owner_graph(owner) {} + + MemoryOptimizerHelper& mem_opt() { + return m_mem_opt; + } + + ComputingGraphImpl* const owner_graph() { + return m_owner_graph; + } + + ThinHashMap& var_map() { + return m_var_map; + } + + /*! + * \brief copy opr and set inputs to m_new_inputs, and add outputs in + * m_var_map + * \return new operator + */ + OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, bool recomp, size_t recomp_cnt=0); + + /*! + * \brief replace input vars according to m_var_map, and store results in + * m_new_inputs; + * \return whether any var is changed + */ + bool replace_vars(const VarNodeArray& inputs); + + //! see memory_optimizer set_priority_before_opt + void set_priority_before_opt(const VarNodeArray& endpoints) { + m_mem_opt.set_priority_before_opt(endpoints); + } + + //! see memory_optimizer restore_graph_option + void restore_graph_option() { + m_mem_opt.restore_graph_option(); + } + +private: + MemoryOptimizerHelper m_mem_opt; + + ComputingGraphImpl* const m_owner_graph = nullptr; + + //! map from original var to replaced var + ThinHashMap m_var_map; + + VarNodeArray m_new_inputs; //!< setup by replace_vars +}; + +struct SeqModifierBase::Opr { + OperatorNodeBase* const orig_opr; + std::vector input, output; + const size_t time; //!< index in opr sequence + const bool is_endpoint; + double estimate_compute_time = 1; + + //! input vars that have been discarded and need to be recomputed before + //! this opr; for internal use by apply_discard_plan() + std::vector inputs_to_recompute; + + //! new oprs to be inserted before this opr; setup by apply_discard_plan() + std::vector::UniquePtr> oprs_insert_before; + + //! [begin, end) interval of *time* for oprs belonging to this block; setup + //! by make_discard_plan() + size_t block_begin_time = 0, block_end_time = 0; + + Opr(OperatorNodeBase* opr, size_t t) + : orig_opr{opr}, + time{t}, + is_endpoint{opr->owner_graph() + ->options() + .opr_attribute.get_sublinear_memory_endpoint( + opr)} {} +}; + +struct SeqModifierBase::Var { + VarNode* const orig_var; + size_t size; //!< memory usage in bytes of this var + size_t recomp_id = 0; + double last_access_time = 0; + + //! write or read access of a var + struct AccessRecord { + Opr* const opr; + const size_t time; + size_t stride; + + explicit AccessRecord(Opr* o = nullptr) + : opr{o}, time{o->time}, stride{0} {} + }; + + //! access_rec[0] is the creation opr, and others are reader oprs + std::vector access_rec; + + /*! + * An index in access_rec + * + * if valid, then the var should be discarded after + * discard_tailing_access->opr finishes + * + * setup by make_discard_plan + */ + Maybe discard_tailing_access; + + /*! + * An index in access_rec + * maintained during make_discard_plan(), for the next access relative to + * current operator + */ + Maybe next_access; + + AccessRecord* visit_discard_tailing_access() { + return discard_tailing_access.valid() + ? &access_rec.at(discard_tailing_access.val()) + : nullptr; + } + + AccessRecord* visit_next_access() { + return next_access.valid() ? &access_rec.at(next_access.val()) + : nullptr; + } + + auto owner_opr() const { return access_rec[0].opr; } + + auto last_access_opr() const { return access_rec.back().opr; } + + Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { + access_rec.emplace_back(opr); + } +}; + +} // namespace cg +} // namespace mgb + +#endif // MGB_ENABLE_SUBLINEAR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/src/core/impl/graph/seq_sublinear_memory.cpp b/src/core/impl/graph/seq_sublinear_memory.cpp index 312d47c5f..eaac0aa81 100644 --- a/src/core/impl/graph/seq_sublinear_memory.cpp +++ b/src/core/impl/graph/seq_sublinear_memory.cpp @@ -61,108 +61,15 @@ bool is_bad_opr(OperatorNodeBase* opr) { } } // namespace -/* ====================== Abstract Opr & Var ====================== */ -struct SeqModifierForSublinearMemory::Opr { - OperatorNodeBase* const orig_opr; - std::vector input, output; - const size_t time; //!< index in opr sequence - const bool is_endpoint; - - //! input vars that have been discarded and need to be recomputed before - //! this opr; for internal use by apply_discard_plan() - std::vector inputs_to_recompute; - - //! new oprs to be inserted before this opr; setup by apply_discard_plan() - std::vector::UniquePtr> oprs_insert_before; - - //! [begin, end) interval of *time* for oprs belonging to this block; setup - //! by make_discard_plan() - size_t block_begin_time = 0, block_end_time = 0; - - Opr(OperatorNodeBase* opr, size_t t) - : orig_opr{opr}, - time{t}, - is_endpoint{opr->owner_graph() - ->options() - .opr_attribute.get_sublinear_memory_endpoint( - opr)} {} -}; - -struct SeqModifierForSublinearMemory::Var { - //! write or read access of a var - struct AccessRecord { - Opr* const opr; - const size_t time; - size_t stride; //!< time distance until next read; 0 for last access - - explicit AccessRecord(Opr* o = nullptr) - : opr{o}, time{o->time}, stride{0} {} - }; - - VarNode* const orig_var; - const size_t size; //!< memory usage in bytes of this var - - //! access_rec[0] is the creation opr, and others are reader oprs - std::vector access_rec; - - /*! - * An index in access_rec - * - * if valid, then the var should be discarded after - * discard_tailing_access->opr finishes - * - * setup by make_discard_plan - */ - Maybe discard_tailing_access; - - /*! - * An index in access_rec - * maintained during make_discard_plan(), for the next access relative to - * current operator - */ - Maybe next_access; - - AccessRecord* visit_discard_tailing_access() { - return discard_tailing_access.valid() - ? &access_rec.at(discard_tailing_access.val()) - : nullptr; - } - - AccessRecord* visit_next_access() { - return next_access.valid() ? &access_rec.at(next_access.val()) - : nullptr; - } - - auto owner_opr() const { return access_rec[0].opr; } - - auto last_access_opr() const { return access_rec.back().opr; } - - Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} { - access_rec.emplace_back(opr); - } -}; /* ====================== ModifyActionPlanner ====================== */ -class SeqModifierForSublinearMemory::ModifyActionPlanner { - //! special creation time used for oprs duplicated from others - static constexpr size_t DUPOPR_TIME = - std::numeric_limits::max() - 1; - +class SeqModifierForSublinearMemory::ModifyActionPlanner : public ModifyActionPlannerBase { using VarArray = std::vector; using VarSet = ThinHashSet; using OprArray = std::vector; - const SeqModifierForSublinearMemory* const m_par_modifier; - const OprNodeArray* m_orig_opr_seq; - - MemPool m_var_mempool; - MemPool m_opr_mempool; - std::vector::UniquePtr> m_var_storage; - std::vector::UniquePtr> m_seq; - - size_t m_nr_endpoint_oprs = 0; - VarSet m_prev_block_discard_vars; std::vector m_blocks; + SeqModifyAction m_action; //! split_point_set to block void split_into_blocks(const SplitPointSet& split_point_set); @@ -188,14 +95,7 @@ class SeqModifierForSublinearMemory::ModifyActionPlanner { public: ModifyActionPlanner(SeqModifierForSublinearMemory* par) - : m_par_modifier{par} {} - - ~ModifyActionPlanner() noexcept { - m_opr_mempool.disable_freelist(); - m_var_mempool.disable_freelist(); - } - //! init m_orig_opr_seq from opr_seq, should be called first. - void init_seq(const OprNodeArray& opr_seq); + : ModifyActionPlannerBase{par} {} //! generate split point set from thresh SplitPointSet get_split_point_set(size_t block_size_thresh); @@ -213,7 +113,7 @@ public: void SeqModifierForSublinearMemory::ModifyActionPlanner::get_prev_action( SeqModifyAction& action) { action.clear(); - for (auto&& opr : m_seq) { + for (auto&& opr : seq()) { auto&& arr = opr->oprs_insert_before; if (arr.empty()) continue; @@ -261,8 +161,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( cur_block_alive_vars.clear(); }; - for (size_t i = 0; i < m_seq.size(); ++i) { - auto opr = m_seq[i].get(); + for (size_t i = 0; i < seq().size(); ++i) { + auto opr = seq()[i].get(); for (auto i : opr->output) add_alive(i); @@ -272,8 +172,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( remove_alive(i); } - if (i + 1 < m_seq.size() && (cur_block_usage < block_size_thresh || - (m_nr_endpoint_oprs && !opr->is_endpoint))) + if (i + 1 < seq().size() && (cur_block_usage < block_size_thresh || + (nr_endpoint_oprs() && !opr->is_endpoint))) continue; flush_block_member(i); @@ -281,81 +181,6 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set( return split_point_set; } -void SeqModifierForSublinearMemory::ModifyActionPlanner::init_seq( - const OprNodeArray& opr_seq) { - m_orig_opr_seq = &opr_seq; - - m_var_storage.clear(); - m_seq.clear(); - m_var_mempool.reorder_free(); - m_opr_mempool.reorder_free(); - m_nr_endpoint_oprs = 0; - - ThinHashMap varmap; - for (auto orig_opr : *m_orig_opr_seq) { - auto time = m_seq.size(); - m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); - auto opr = m_seq.back().get(); - m_nr_endpoint_oprs += opr->is_endpoint; - - for (auto&& dep : orig_opr->node_prop().dep_map()) { - if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second)) - continue; - - auto iter = varmap.find(dep.first); - if (iter == varmap.end()) { - // input var needs not to be considered - continue; - } - - auto ivar = iter->second; - bool exist = false; - for (auto i : opr->input) { - if (i == ivar) { - exist = true; - break; - } - } - if (exist) { - // same var for different inputs - continue; - } - - opr->input.push_back(ivar); - auto&& prev_rec = ivar->access_rec.back(); - prev_rec.stride = time - prev_rec.opr->time; - ivar->access_rec.emplace_back(opr); - } - - for (auto i : orig_opr->output()) { - auto var2memsize = m_par_modifier->m_mem_opt.var2memsize(); - auto iter = var2memsize->find(i); - if (iter == var2memsize->end()) { - // some vars are ignored; see split_into_cn2oprseq() - continue; - } - m_var_storage.emplace_back( - m_var_mempool.alloc_unique(i, iter->second, opr)); - auto ovar = m_var_storage.back().get(); - varmap[i] = ovar; - opr->output.push_back(ovar); - } - mgb_assert(!opr->output.empty()); - } - - // remove unused output - for (auto&& i : m_seq) { - auto&& oarr = i->output; - for (size_t j = 0; j < oarr.size();) { - if (oarr[j]->access_rec.size() == 1) { - std::swap(oarr[j], oarr.back()); - oarr.pop_back(); - } else - ++j; - } - } -} - size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: calc_bottleneck_from_discard_plan() { size_t cur_usage = 0, max_usage = 0; @@ -394,7 +219,7 @@ size_t SeqModifierForSublinearMemory::ModifyActionPlanner:: ++time; }; - for (auto&& opr : m_seq) { + for (auto&& opr : seq()) { for (auto&& i : opr->oprs_insert_before) process_opr(i.get()); process_opr(opr.get()); @@ -480,7 +305,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { mgb_assert(opr->time < block_end); - auto new_opr_storage = m_opr_mempool.alloc_unique( + auto new_opr_storage = opr_mempool().alloc_unique( opr->orig_opr, static_cast(DUPOPR_TIME)); auto new_opr = new_opr_storage.get(); @@ -497,7 +322,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { Var* new_var = nullptr; for (auto i : opr->output) { - auto&& ovar = m_var_mempool.alloc_unique(i->orig_var, i->size, + auto&& ovar = var_mempool().alloc_unique(i->orig_var, i->size, new_opr); new_opr->output.push_back(ovar.get()); if (i == var) @@ -507,7 +332,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { auto ins = var_map.insert({i, ovar.get()}); mgb_assert(ins.second); - m_var_storage.emplace_back(std::move(ovar)); + var_storage().emplace_back(std::move(ovar)); } mgb_assert(new_var); return new_var; @@ -515,7 +340,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() { add_dep(var); }; - for (auto&& _raw_opr : m_seq) { + for (auto&& _raw_opr : seq()) { auto opr = _raw_opr.get(); for (auto i : opr->inputs_to_recompute) @@ -640,8 +465,8 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( m_blocks.clear(); std::vector cur_block_member; size_t i, j; - for (i = j = 0; i < m_seq.size() && j < split_point_set->size(); ++i) { - auto opr = m_seq[i].get(); + for (i = j = 0; i < seq().size() && j < split_point_set->size(); ++i) { + auto opr = seq()[i].get(); cur_block_member.push_back(opr); if (i != split_point_set->at(j)) continue; @@ -649,7 +474,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks( cur_block_member.clear(); j++; } - mgb_assert(i >= m_seq.size()); + mgb_assert(i >= seq().size()); mgb_assert(j >= split_point_set->size()); } @@ -1081,7 +906,7 @@ void SeqModifierForSublinearMemory::InternalDeleter::operator()( } void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { - m_var_map.clear(); + var_map().clear(); m_opr2replace_info.clear(); auto config = MemoryOptimizerHelper::SubGraphConfig() @@ -1099,7 +924,7 @@ void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) { .add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) .add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE); - auto cn2oprseq = m_mem_opt.split_into_cn2oprseq(oprseq, config); + auto cn2oprseq = mem_opt().split_into_cn2oprseq(oprseq, config); if (cn2oprseq->empty()) { // empty graph @@ -1175,7 +1000,7 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, // each operator should be set no more than once auto set_priority = [&](OperatorNodeBase* opr) { mgb_assert(modified_opr.insert(opr).second); - m_mem_opt.set_priority(opr, cur_priority++); + mem_opt().set_priority(opr, cur_priority++); }; auto on_opr_visited = [&](OperatorNodeBase* opr) { @@ -1218,80 +1043,13 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action, mgb_assert(action.empty()); } -bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) { - m_new_inputs.assign(inputs.begin(), inputs.end()); - bool changed = false; - for (auto&& i : m_new_inputs) { - auto iter = m_var_map.find(i); - if (iter != m_var_map.end()) { - i = iter->second; - changed = true; - } - } - return changed; -} - -OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs( - OperatorNodeBase* opr, bool recomp) { - auto config = opr->config(); - // update operator instance id to bybass the shallow copy's cache if - // it's a dup-opr-copying due to discarding. - // Don't update instance id by `this` pointer if it's a recomp-opr-copying - // because: - // 0) recomp-opr would be copied iff its input vars is changed - // 1) some pair of recomp-opr and dup-opr have the same inputs, params - // and config, we use instance id to differentiate them. - config.name(opr->name() + (recomp ? ":recomp" : ":dup")); - if (!recomp) { - config.update_instance_id(this); - } - - // Note: if all outputs of op were placed on the same comp_node, since its - // stream maybe changed during seq_comp_node_opt, output's comp_node has - // higher priority than opr->config() - auto out_cn = opr->output(0)->comp_node(); - for (auto i : opr->output()) { - auto cn = i->comp_node(); - if (out_cn != cn) { - out_cn = {}; - break; - } - } - if (out_cn.valid()) - config.comp_node(out_cn); - - auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config); - mgb_assert(opr_new != opr); - - auto&& out0 = opr->output(); - auto&& out1 = opr_new->output(); - mgb_assert(out0.size() == out1.size()); - bool stream_changed = false; - for (size_t i = 0; i < out0.size(); ++i) { - auto &&cn0 = out0[i]->comp_node(), - &&cn1 = out1[i]->comp_node(); - if (cn0 != cn1) { - mgb_assert(recomp); - mgb_assert(cn0.locator().type == cn1.locator().type && - cn0.locator().device == cn1.locator().device); - out1[i]->comp_node(cn0); - stream_changed = true; - } - m_var_map[out0[i]] = out1[i]; - } - if (stream_changed) { - opr_new->on_output_comp_node_stream_changed(); - } - return opr_new; -} - void SeqModifierForSublinearMemory::modify_endpoint_vars( VarNodeArray& endpoints) { - auto comp_seq = MemoryOptimizerHelper::CompSeq(m_owner_graph, endpoints); + auto comp_seq = MemoryOptimizerHelper::CompSeq(owner_graph(), endpoints); reset_opr_seq(*comp_seq.m_seq); for (auto&& i : endpoints) { - auto iter = m_var_map.find(i); - if (iter != m_var_map.end()) { + auto iter = var_map().find(i); + if (iter != var_map().end()) { i = iter->second; } } @@ -1357,8 +1115,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() { SeqModifierForSublinearMemory::SeqModifierForSublinearMemory( ComputingGraphImpl* owner, Config* config_p) - : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {} + : SeqModifierBase(owner), m_config(config_p) {} #endif // !MGB_ENABLE_SUBLINEAR -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/src/core/impl/graph/seq_sublinear_memory.h b/src/core/impl/graph/seq_sublinear_memory.h index 6b79b2f4c..13c8f5f66 100644 --- a/src/core/impl/graph/seq_sublinear_memory.h +++ b/src/core/impl/graph/seq_sublinear_memory.h @@ -12,6 +12,7 @@ #pragma once #include "./memory_optimizer.h" +#include "./seq_modifier_base.h" #include "megbrain/graph/cg.h" #include "megbrain/utils/async_worker.h" @@ -23,28 +24,31 @@ namespace cg { * \brief modifying computing sequence, with basically the same idea of Training * Deep Nets with Sublinear Memory Cost */ -class SeqModifierForSublinearMemory { - /*! - * describes modifications that should be applied to an operator sequnce: - * maps from an opr to the oprs that should be duplicated and inserted - * before it. - */ - using SeqModifyAction = std::unordered_map; - using SplitPointSet = std::shared_ptr>; - +class SeqModifierForSublinearMemory : public SeqModifierBase { //! Config options using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig; Config* m_config; +public: + SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); + + //! replace endpoint vars by the ones that require more computing + void modify_endpoint_vars(VarNodeArray& endpoints); + + //! check whether actual opr_seq is what we expect; throw InternalError + void sanity_check(const OprNodeArray& opr_seq); + + const CompNode::UnorderedMap& prev_min_bottleneck(); + +private: + using SplitPointSet = std::shared_ptr>; + //! get modifications to be taken under some specific constraints class ModifyActionPlanner; //! search best modify action for opr seq on a single comp node class ActionSearcherSingleCN; - struct Opr; - struct Var; - struct InternalDeleter { void operator()(ActionSearcherSingleCN*) const; void operator()(ModifyActionPlanner*) const; @@ -67,32 +71,8 @@ class SeqModifierForSublinearMemory { //! thread pool to run ModifyActionPlanner FutureThreadPool m_planner_thread_pool; - //! map from original var to replaced var - ThinHashMap m_var_map; - - VarNodeArray m_new_inputs; //!< setup by replace_vars - - MemoryOptimizerHelper m_mem_opt; - - ComputingGraphImpl* const m_owner_graph = nullptr; - CompNode::UnorderedMap m_prev_min_bottleneck; - /*! - * \brief replace input vars according to m_var_map, and store results in - * m_new_inputs; - * \return whether any var is changed - */ - bool replace_vars(const VarNodeArray& inputs); - - /*! - * \brief copy opr and set inputs to m_new_inputs, and add outputs in - * m_var_map - * \return new operator - */ - OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr, - bool recomp); - //! restore computing sequence and modify operator priority void reset_opr_seq(const OprNodeArray& oprseq); @@ -107,27 +87,6 @@ class SeqModifierForSublinearMemory { return std::make_shared( std::forward(args)...); } - -public: - SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g); - - //! see memory_optimizer set_priority_before_opt - void set_priority_before_opt(const VarNodeArray& endpoints) { - m_mem_opt.set_priority_before_opt(endpoints); - } - - //! see memory_optimizer restore_graph_option - void restore_graph_option() { - m_mem_opt.restore_graph_option(); - } - - //! replace endpoint vars by the ones that require more computing - void modify_endpoint_vars(VarNodeArray& endpoints); - - //! check whether actual opr_seq is what we expect; throw InternalError - void sanity_check(const OprNodeArray& opr_seq); - - const CompNode::UnorderedMap& prev_min_bottleneck(); }; } // namespace cg -- GitLab