提交 76f4f975 编写于 作者: M Megvii Engine Team

refactor(sublinear): add SeqModifierBase

GitOrigin-RevId: 2d0393be6b950690c5960ac63bd47931a3afb324
上级 f584416a
/**
* \file src/core/impl/graph/seq_modifier_base.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./seq_modifier_base.h"
#if MGB_ENABLE_SUBLINEAR
using namespace mgb;
using namespace cg;
void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq) {
m_orig_opr_seq = &opr_seq;
m_var_storage.clear();
m_seq.clear();
m_var_mempool.reorder_free();
m_opr_mempool.reorder_free();
m_nr_endpoint_oprs = 0;
ThinHashMap<VarNode*, Var*> varmap;
for (auto orig_opr : *m_orig_opr_seq) {
auto time = m_seq.size();
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time));
auto opr = m_seq.back().get();
m_nr_endpoint_oprs += opr->is_endpoint;
for (auto&& dep : orig_opr->node_prop().dep_map()) {
if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second))
continue;
auto iter = varmap.find(dep.first);
if (iter == varmap.end()) {
// input var needs not to be considered
continue;
}
auto ivar = iter->second;
bool exist = false;
for (auto i : opr->input) {
if (i == ivar) {
exist = true;
break;
}
}
if (exist) {
// same var for different inputs
continue;
}
opr->input.push_back(ivar);
auto&& prev_rec = ivar->access_rec.back();
prev_rec.stride = time - prev_rec.opr->time;
ivar->access_rec.emplace_back(opr);
}
for (auto i : orig_opr->output()) {
auto var2memsize = m_par_modifier->m_mem_opt.var2memsize();
auto iter = var2memsize->find(i);
if (iter == var2memsize->end()) {
// some vars are ignored; see split_into_cn2oprseq()
continue;
}
m_var_storage.emplace_back(
m_var_mempool.alloc_unique(i, iter->second, opr));
auto ovar = m_var_storage.back().get();
varmap[i] = ovar;
opr->output.push_back(ovar);
}
mgb_assert(!opr->output.empty());
}
// remove unused output
for (auto&& i : m_seq) {
auto&& oarr = i->output;
for (size_t j = 0; j < oarr.size();) {
if (oarr[j]->access_rec.size() == 1) {
std::swap(oarr[j], oarr.back());
oarr.pop_back();
} else
++j;
}
}
}
/*!
 * Copy \p inputs into m_new_inputs, substituting every var that has an entry
 * in m_var_map by its mapped var.
 *
 * \param inputs vars to translate; left untouched
 * \return true if at least one var was substituted
 */
bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) {
    m_new_inputs.assign(inputs.begin(), inputs.end());

    bool any_replaced = false;
    for (auto& var : m_new_inputs) {
        auto found = m_var_map.find(var);
        if (found == m_var_map.end())
            continue;
        var = found->second;
        any_replaced = true;
    }
    return any_replaced;
}
/*!
 * Shallow-copy \p opr with m_new_inputs as its inputs and record the mapping
 * from every old output var to the matching new output var in m_var_map.
 *
 * \param opr operator to copy
 * \param recomp picks the ":recomp" (true) or ":dup" (false) name suffix and
 *      enters the instance id computation; a comp node mismatch between old
 *      and new outputs is only tolerated when it is true
 * \param recomp_cnt appended to the new name and scales the instance id offset
 * \return the newly created operator (asserted to differ from \p opr)
 */
OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs(
        OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) {
    auto config = opr->config();

    const char* suffix = recomp ? ":recomp" : ":dup";
    config.name(opr->name() + suffix + std::to_string(recomp_cnt));

    // The instance id is always updated: it is `this` shifted by
    // ((recomp + 1) << 10) * recomp_cnt, i.e. exactly `this` when recomp_cnt
    // is 0.
    // NOTE(review): presumably this bypasses the shallow copy's cache and
    // tells recomp copies from dup copies that share inputs, params and
    // config -- confirm against update_instance_id().
    const size_t kind_step = (static_cast<size_t>(recomp) + 1) << 10;
    const auto id_base = reinterpret_cast<size_t>(this);
    config.update_instance_id(
            reinterpret_cast<void*>(id_base + kind_step * recomp_cnt));

    // Note: if all outputs of op were placed on the same comp_node, since its
    // stream maybe changed during seq_comp_node_opt, output's comp_node has
    // higher priority than opr->config()
    auto out_cn = opr->output(0)->comp_node();
    for (auto i : opr->output()) {
        if (out_cn != i->comp_node()) {
            // outputs disagree: reset to a default comp node so that the
            // valid() check below decides whether to override the config
            out_cn = {};
            break;
        }
    }
    if (out_cn.valid())
        config.comp_node(out_cn);

    auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config);
    mgb_assert(opr_new != opr);

    auto&& out0 = opr->output();
    auto&& out1 = opr_new->output();
    mgb_assert(out0.size() == out1.size());

    bool stream_changed = false;
    for (size_t idx = 0; idx < out0.size(); ++idx) {
        auto&& cn0 = out0[idx]->comp_node();
        auto&& cn1 = out1[idx]->comp_node();
        if (cn0 != cn1) {
            // only a recomp copy may land on another comp node, and only one
            // with the same locator type and device; force the original back
            mgb_assert(recomp);
            mgb_assert(cn0.locator().type == cn1.locator().type &&
                       cn0.locator().device == cn1.locator().device);
            out1[idx]->comp_node(cn0);
            stream_changed = true;
        }
        m_var_map[out0[idx]] = out1[idx];
    }

    if (stream_changed)
        opr_new->on_output_comp_node_stream_changed();
    return opr_new;
}
#endif // MGB_ENABLE_SUBLINEAR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
/**
* \file src/core/impl/graph/seq_modifier_base.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./memory_optimizer.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/cg.h"
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/system.h"
#include "megbrain/utils/async_worker.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/timer.h"
#if MGB_ENABLE_SUBLINEAR
namespace mgb {
namespace cg {
/*!
 * \brief modifying computing sequence, with basically the same idea of Training
 *      Deep Nets with Sublinear Memory Cost and Dynamic Tensor Rematerialization
 *
 * Holds the state shared by sequence modifiers: the memory optimizer helper,
 * the owner graph, the original-var -> replaced-var map and the scratch input
 * array used when copying operators.
 */
class SeqModifierBase {
public:
    /*!
     * describes modifications that should be applied to an operator sequence:
     * maps from an opr to the oprs that should be duplicated and inserted
     * before it.
     */
    using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;

    struct Var;
    struct Opr;

    /*!
     * \brief common state of action planners: the abstract Opr/Var sequence
     *      built by init_seq() and the pools that own it
     */
    class ModifyActionPlannerBase {
        const SeqModifierBase* const m_par_modifier;
        //! original sequence given to init_seq(); null until init_seq() runs
        const OprNodeArray* m_orig_opr_seq = nullptr;
        MemPool<Var> m_var_mempool;
        MemPool<Opr> m_opr_mempool;
        //! owns the Var objects allocated from m_var_mempool
        std::vector<MemPool<Var>::UniquePtr> m_var_storage;
        //! abstract opr sequence; init_seq() gives each Opr its index as time
        std::vector<MemPool<Opr>::UniquePtr> m_seq;
        //! number of oprs in m_seq whose is_endpoint is set
        size_t m_nr_endpoint_oprs = 0;

    public:
        //! special creation time used for oprs duplicated from others
        static constexpr size_t DUPOPR_TIME =
                std::numeric_limits<size_t>::max() - 1;

        //! the modifier this planner was created for
        const SeqModifierBase* par_modifier() const { return m_par_modifier; }

        //! sequence passed to the last init_seq(); nullptr before the first call
        const OprNodeArray* orig_opr_seq() const { return m_orig_opr_seq; }

        MemPool<Var>& var_mempool() { return m_var_mempool; }

        MemPool<Opr>& opr_mempool() { return m_opr_mempool; }

        std::vector<MemPool<Var>::UniquePtr>& var_storage() {
            return m_var_storage;
        }

        std::vector<MemPool<Opr>::UniquePtr>& seq() { return m_seq; }

        size_t& nr_endpoint_oprs() { return m_nr_endpoint_oprs; }

        ModifyActionPlannerBase(SeqModifierBase* par) : m_par_modifier{par} {}

        //! disables the freelist of both pools before the members are destroyed
        ~ModifyActionPlannerBase() noexcept {
            m_opr_mempool.disable_freelist();
            m_var_mempool.disable_freelist();
        }

        //! init m_orig_opr_seq from opr_seq, should be called first.
        void init_seq(const OprNodeArray& opr_seq);
    };

    SeqModifierBase(ComputingGraphImpl* owner)
            : m_mem_opt(owner), m_owner_graph(owner) {}

    MemoryOptimizerHelper& mem_opt() { return m_mem_opt; }

    //! graph passed to the constructor
    ComputingGraphImpl* owner_graph() const { return m_owner_graph; }

    ThinHashMap<VarNode*, VarNode*>& var_map() { return m_var_map; }

    /*!
     * \brief copy opr and set inputs to m_new_inputs, and add outputs in
     *      m_var_map
     * \param opr operator to copy
     * \param recomp whether the copy is a recomputation (otherwise a dup)
     * \param recomp_cnt counter appended to the copy's name and used in its
     *      instance id
     * \return new operator
     */
    OperatorNodeBase* copy_opr_from_new_inputs(
            OperatorNodeBase* opr, bool recomp, size_t recomp_cnt = 0);

    /*!
     * \brief replace input vars according to m_var_map, and store results in
     *      m_new_inputs;
     * \return whether any var is changed
     */
    bool replace_vars(const VarNodeArray& inputs);

    //! see memory_optimizer set_priority_before_opt
    void set_priority_before_opt(const VarNodeArray& endpoints) {
        m_mem_opt.set_priority_before_opt(endpoints);
    }

    //! see memory_optimizer restore_graph_option
    void restore_graph_option() { m_mem_opt.restore_graph_option(); }

private:
    MemoryOptimizerHelper m_mem_opt;
    ComputingGraphImpl* const m_owner_graph = nullptr;
    //! map from original var to replaced var
    ThinHashMap<VarNode*, VarNode*> m_var_map;
    VarNodeArray m_new_inputs;  //!< setup by replace_vars
};
/*!
 * \brief abstract operator used by the planners: wraps an original operator
 *      together with the abstract vars it reads and writes
 */
struct SeqModifierBase::Opr {
    //! the wrapped operator of the original sequence
    OperatorNodeBase* const orig_opr;

    std::vector<Var*> input;
    std::vector<Var*> output;

    //! index in opr sequence
    const size_t time;

    //! value of the graph's sublinear_memory_endpoint attribute for orig_opr,
    //! read once at construction
    const bool is_endpoint;

    double estimate_compute_time = 1;

    //! input vars that have been discarded and need to be recomputed before
    //! this opr; for internal use by apply_discard_plan()
    std::vector<Var*> inputs_to_recompute;

    //! new oprs to be inserted before this opr; setup by apply_discard_plan()
    std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before;

    //! [begin, end) interval of *time* for oprs belonging to this block; setup
    //! by make_discard_plan()
    size_t block_begin_time = 0;
    size_t block_end_time = 0;

    /*!
     * \param opr operator to wrap; dereferenced to query its owner graph
     * \param t value stored as time
     */
    Opr(OperatorNodeBase* opr, size_t t)
            : orig_opr{opr},
              time{t},
              is_endpoint{opr->owner_graph()
                                  ->options()
                                  .opr_attribute
                                  .get_sublinear_memory_endpoint(opr)} {}
};
/*!
 * \brief abstract var used by the planners: wraps an original VarNode together
 *      with its size and the ordered list of oprs that access it
 */
struct SeqModifierBase::Var {
    VarNode* const orig_var;
    size_t size;  //!< memory usage in bytes of this var
    size_t recomp_id = 0;
    double last_access_time = 0;

    //! write or read access of a var
    struct AccessRecord {
        Opr* const opr;
        const size_t time;  //!< copy of opr->time; 0 if constructed without opr
        size_t stride;  //!< time distance until next read; 0 for last access

        /*!
         * \param o accessing opr; may be null (the default), in which case
         *      time is 0 instead of dereferencing a null pointer
         */
        explicit AccessRecord(Opr* o = nullptr)
                : opr{o}, time{o ? o->time : 0}, stride{0} {}
    };

    //! access_rec[0] is the creation opr, and others are reader oprs
    std::vector<AccessRecord> access_rec;

    /*!
     * An index in access_rec
     *
     * if valid, then the var should be discarded after
     * discard_tailing_access->opr finishes
     *
     * setup by make_discard_plan
     */
    Maybe<size_t> discard_tailing_access;

    /*!
     * An index in access_rec
     * maintained during make_discard_plan(), for the next access relative to
     * current operator
     */
    Maybe<size_t> next_access;

    //! record selected by discard_tailing_access, or nullptr if it is not
    //! valid; the index is bounds-checked by std::vector::at()
    AccessRecord* visit_discard_tailing_access() {
        return discard_tailing_access.valid()
                     ? &access_rec.at(discard_tailing_access.val())
                     : nullptr;
    }

    //! record selected by next_access, or nullptr if it is not valid; the
    //! index is bounds-checked by std::vector::at()
    AccessRecord* visit_next_access() {
        return next_access.valid() ? &access_rec.at(next_access.val())
                                   : nullptr;
    }

    //! opr of the first access record, i.e. the creator
    auto owner_opr() const { return access_rec[0].opr; }

    //! opr of the most recently appended access record
    auto last_access_opr() const { return access_rec.back().opr; }

    /*!
     * \param var original var
     * \param s stored as size
     * \param opr creating opr; becomes access_rec[0]
     */
    Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} {
        access_rec.emplace_back(opr);
    }
};
} // namespace cg
} // namespace mgb
#endif // MGB_ENABLE_SUBLINEAR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -61,108 +61,15 @@ bool is_bad_opr(OperatorNodeBase* opr) {
}
} // namespace
/* ====================== Abstract Opr & Var ====================== */
struct SeqModifierForSublinearMemory::Opr {
OperatorNodeBase* const orig_opr;
std::vector<Var*> input, output;
const size_t time; //!< index in opr sequence
const bool is_endpoint;
//! input vars that have been discarded and need to be recomputed before
//! this opr; for internal use by apply_discard_plan()
std::vector<Var*> inputs_to_recompute;
//! new oprs to be inserted before this opr; setup by apply_discard_plan()
std::vector<MemPool<Opr>::UniquePtr> oprs_insert_before;
//! [begin, end) interval of *time* for oprs belonging to this block; setup
//! by make_discard_plan()
size_t block_begin_time = 0, block_end_time = 0;
Opr(OperatorNodeBase* opr, size_t t)
: orig_opr{opr},
time{t},
is_endpoint{opr->owner_graph()
->options()
.opr_attribute.get_sublinear_memory_endpoint(
opr)} {}
};
struct SeqModifierForSublinearMemory::Var {
//! write or read access of a var
struct AccessRecord {
Opr* const opr;
const size_t time;
size_t stride; //!< time distance until next read; 0 for last access
explicit AccessRecord(Opr* o = nullptr)
: opr{o}, time{o->time}, stride{0} {}
};
VarNode* const orig_var;
const size_t size; //!< memory usage in bytes of this var
//! access_rec[0] is the creation opr, and others are reader oprs
std::vector<AccessRecord> access_rec;
/*!
* An index in access_rec
*
* if valid, then the var should be discarded after
* discard_tailing_access->opr finishes
*
* setup by make_discard_plan
*/
Maybe<size_t> discard_tailing_access;
/*!
* An index in access_rec
* maintained during make_discard_plan(), for the next access relative to
* current operator
*/
Maybe<size_t> next_access;
AccessRecord* visit_discard_tailing_access() {
return discard_tailing_access.valid()
? &access_rec.at(discard_tailing_access.val())
: nullptr;
}
AccessRecord* visit_next_access() {
return next_access.valid() ? &access_rec.at(next_access.val())
: nullptr;
}
auto owner_opr() const { return access_rec[0].opr; }
auto last_access_opr() const { return access_rec.back().opr; }
Var(VarNode* var, size_t s, Opr* opr) : orig_var{var}, size{s} {
access_rec.emplace_back(opr);
}
};
/* ====================== ModifyActionPlanner ====================== */
class SeqModifierForSublinearMemory::ModifyActionPlanner {
//! special creation time used for oprs duplicated from others
static constexpr size_t DUPOPR_TIME =
std::numeric_limits<size_t>::max() - 1;
class SeqModifierForSublinearMemory::ModifyActionPlanner : public ModifyActionPlannerBase {
using VarArray = std::vector<Var*>;
using VarSet = ThinHashSet<Var*>;
using OprArray = std::vector<Opr*>;
const SeqModifierForSublinearMemory* const m_par_modifier;
const OprNodeArray* m_orig_opr_seq;
MemPool<Var> m_var_mempool;
MemPool<Opr> m_opr_mempool;
std::vector<MemPool<Var>::UniquePtr> m_var_storage;
std::vector<MemPool<Opr>::UniquePtr> m_seq;
size_t m_nr_endpoint_oprs = 0;
VarSet m_prev_block_discard_vars;
std::vector<OprArray> m_blocks;
SeqModifyAction m_action;
//! split_point_set to block
void split_into_blocks(const SplitPointSet& split_point_set);
......@@ -188,14 +95,7 @@ class SeqModifierForSublinearMemory::ModifyActionPlanner {
public:
ModifyActionPlanner(SeqModifierForSublinearMemory* par)
: m_par_modifier{par} {}
~ModifyActionPlanner() noexcept {
m_opr_mempool.disable_freelist();
m_var_mempool.disable_freelist();
}
//! init m_orig_opr_seq from opr_seq, should be called first.
void init_seq(const OprNodeArray& opr_seq);
: ModifyActionPlannerBase{par} {}
//! generate split point set from thresh
SplitPointSet get_split_point_set(size_t block_size_thresh);
......@@ -213,7 +113,7 @@ public:
void SeqModifierForSublinearMemory::ModifyActionPlanner::get_prev_action(
SeqModifyAction& action) {
action.clear();
for (auto&& opr : m_seq) {
for (auto&& opr : seq()) {
auto&& arr = opr->oprs_insert_before;
if (arr.empty())
continue;
......@@ -261,8 +161,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set(
cur_block_alive_vars.clear();
};
for (size_t i = 0; i < m_seq.size(); ++i) {
auto opr = m_seq[i].get();
for (size_t i = 0; i < seq().size(); ++i) {
auto opr = seq()[i].get();
for (auto i : opr->output)
add_alive(i);
......@@ -272,8 +172,8 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set(
remove_alive(i);
}
if (i + 1 < m_seq.size() && (cur_block_usage < block_size_thresh ||
(m_nr_endpoint_oprs && !opr->is_endpoint)))
if (i + 1 < seq().size() && (cur_block_usage < block_size_thresh ||
(nr_endpoint_oprs() && !opr->is_endpoint)))
continue;
flush_block_member(i);
......@@ -281,81 +181,6 @@ SeqModifierForSublinearMemory::ModifyActionPlanner::get_split_point_set(
return split_point_set;
}
void SeqModifierForSublinearMemory::ModifyActionPlanner::init_seq(
const OprNodeArray& opr_seq) {
m_orig_opr_seq = &opr_seq;
m_var_storage.clear();
m_seq.clear();
m_var_mempool.reorder_free();
m_opr_mempool.reorder_free();
m_nr_endpoint_oprs = 0;
ThinHashMap<VarNode*, Var*> varmap;
for (auto orig_opr : *m_orig_opr_seq) {
auto time = m_seq.size();
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time));
auto opr = m_seq.back().get();
m_nr_endpoint_oprs += opr->is_endpoint;
for (auto&& dep : orig_opr->node_prop().dep_map()) {
if (!OperatorNodeBase::NodeProp::is_device_value_dep(dep.second))
continue;
auto iter = varmap.find(dep.first);
if (iter == varmap.end()) {
// input var needs not to be considered
continue;
}
auto ivar = iter->second;
bool exist = false;
for (auto i : opr->input) {
if (i == ivar) {
exist = true;
break;
}
}
if (exist) {
// same var for different inputs
continue;
}
opr->input.push_back(ivar);
auto&& prev_rec = ivar->access_rec.back();
prev_rec.stride = time - prev_rec.opr->time;
ivar->access_rec.emplace_back(opr);
}
for (auto i : orig_opr->output()) {
auto var2memsize = m_par_modifier->m_mem_opt.var2memsize();
auto iter = var2memsize->find(i);
if (iter == var2memsize->end()) {
// some vars are ignored; see split_into_cn2oprseq()
continue;
}
m_var_storage.emplace_back(
m_var_mempool.alloc_unique(i, iter->second, opr));
auto ovar = m_var_storage.back().get();
varmap[i] = ovar;
opr->output.push_back(ovar);
}
mgb_assert(!opr->output.empty());
}
// remove unused output
for (auto&& i : m_seq) {
auto&& oarr = i->output;
for (size_t j = 0; j < oarr.size();) {
if (oarr[j]->access_rec.size() == 1) {
std::swap(oarr[j], oarr.back());
oarr.pop_back();
} else
++j;
}
}
}
size_t SeqModifierForSublinearMemory::ModifyActionPlanner::
calc_bottleneck_from_discard_plan() {
size_t cur_usage = 0, max_usage = 0;
......@@ -394,7 +219,7 @@ size_t SeqModifierForSublinearMemory::ModifyActionPlanner::
++time;
};
for (auto&& opr : m_seq) {
for (auto&& opr : seq()) {
for (auto&& i : opr->oprs_insert_before)
process_opr(i.get());
process_opr(opr.get());
......@@ -480,7 +305,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
mgb_assert(opr->time < block_end);
auto new_opr_storage = m_opr_mempool.alloc_unique(
auto new_opr_storage = opr_mempool().alloc_unique(
opr->orig_opr, static_cast<size_t>(DUPOPR_TIME));
auto new_opr = new_opr_storage.get();
......@@ -497,7 +322,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
Var* new_var = nullptr;
for (auto i : opr->output) {
auto&& ovar = m_var_mempool.alloc_unique(i->orig_var, i->size,
auto&& ovar = var_mempool().alloc_unique(i->orig_var, i->size,
new_opr);
new_opr->output.push_back(ovar.get());
if (i == var)
......@@ -507,7 +332,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
auto ins = var_map.insert({i, ovar.get()});
mgb_assert(ins.second);
m_var_storage.emplace_back(std::move(ovar));
var_storage().emplace_back(std::move(ovar));
}
mgb_assert(new_var);
return new_var;
......@@ -515,7 +340,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::apply_discard_plan() {
add_dep(var);
};
for (auto&& _raw_opr : m_seq) {
for (auto&& _raw_opr : seq()) {
auto opr = _raw_opr.get();
for (auto i : opr->inputs_to_recompute)
......@@ -640,8 +465,8 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks(
m_blocks.clear();
std::vector<Opr*> cur_block_member;
size_t i, j;
for (i = j = 0; i < m_seq.size() && j < split_point_set->size(); ++i) {
auto opr = m_seq[i].get();
for (i = j = 0; i < seq().size() && j < split_point_set->size(); ++i) {
auto opr = seq()[i].get();
cur_block_member.push_back(opr);
if (i != split_point_set->at(j))
continue;
......@@ -649,7 +474,7 @@ void SeqModifierForSublinearMemory::ModifyActionPlanner::split_into_blocks(
cur_block_member.clear();
j++;
}
mgb_assert(i >= m_seq.size());
mgb_assert(i >= seq().size());
mgb_assert(j >= split_point_set->size());
}
......@@ -1081,7 +906,7 @@ void SeqModifierForSublinearMemory::InternalDeleter::operator()(
}
void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) {
m_var_map.clear();
var_map().clear();
m_opr2replace_info.clear();
auto config =
MemoryOptimizerHelper::SubGraphConfig()
......@@ -1099,7 +924,7 @@ void SeqModifierForSublinearMemory::reset_opr_seq(const OprNodeArray& oprseq) {
.add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE);
auto cn2oprseq = m_mem_opt.split_into_cn2oprseq(oprseq, config);
auto cn2oprseq = mem_opt().split_into_cn2oprseq(oprseq, config);
if (cn2oprseq->empty()) {
// empty graph
......@@ -1175,7 +1000,7 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action,
// each operator should be set no more than once
auto set_priority = [&](OperatorNodeBase* opr) {
mgb_assert(modified_opr.insert(opr).second);
m_mem_opt.set_priority(opr, cur_priority++);
mem_opt().set_priority(opr, cur_priority++);
};
auto on_opr_visited = [&](OperatorNodeBase* opr) {
......@@ -1218,80 +1043,13 @@ void SeqModifierForSublinearMemory::apply_action(SeqModifyAction& action,
mgb_assert(action.empty());
}
bool SeqModifierForSublinearMemory::replace_vars(const VarNodeArray& inputs) {
m_new_inputs.assign(inputs.begin(), inputs.end());
bool changed = false;
for (auto&& i : m_new_inputs) {
auto iter = m_var_map.find(i);
if (iter != m_var_map.end()) {
i = iter->second;
changed = true;
}
}
return changed;
}
OperatorNodeBase* SeqModifierForSublinearMemory::copy_opr_from_new_inputs(
OperatorNodeBase* opr, bool recomp) {
auto config = opr->config();
// update operator instance id to bybass the shallow copy's cache if
// it's a dup-opr-copying due to discarding.
// Don't update instance id by `this` pointer if it's a recomp-opr-copying
// because:
// 0) recomp-opr would be copied iff its input vars is changed
// 1) some pair of recomp-opr and dup-opr have the same inputs, params
// and config, we use instance id to differentiate them.
config.name(opr->name() + (recomp ? ":recomp" : ":dup"));
if (!recomp) {
config.update_instance_id(this);
}
// Note: if all outputs of op were placed on the same comp_node, since its
// stream maybe changed during seq_comp_node_opt, output's comp_node has
// higher priority than opr->config()
auto out_cn = opr->output(0)->comp_node();
for (auto i : opr->output()) {
auto cn = i->comp_node();
if (out_cn != cn) {
out_cn = {};
break;
}
}
if (out_cn.valid())
config.comp_node(out_cn);
auto opr_new = serialization::copy_opr_shallow(*opr, m_new_inputs, config);
mgb_assert(opr_new != opr);
auto&& out0 = opr->output();
auto&& out1 = opr_new->output();
mgb_assert(out0.size() == out1.size());
bool stream_changed = false;
for (size_t i = 0; i < out0.size(); ++i) {
auto &&cn0 = out0[i]->comp_node(),
&&cn1 = out1[i]->comp_node();
if (cn0 != cn1) {
mgb_assert(recomp);
mgb_assert(cn0.locator().type == cn1.locator().type &&
cn0.locator().device == cn1.locator().device);
out1[i]->comp_node(cn0);
stream_changed = true;
}
m_var_map[out0[i]] = out1[i];
}
if (stream_changed) {
opr_new->on_output_comp_node_stream_changed();
}
return opr_new;
}
void SeqModifierForSublinearMemory::modify_endpoint_vars(
VarNodeArray& endpoints) {
auto comp_seq = MemoryOptimizerHelper::CompSeq(m_owner_graph, endpoints);
auto comp_seq = MemoryOptimizerHelper::CompSeq(owner_graph(), endpoints);
reset_opr_seq(*comp_seq.m_seq);
for (auto&& i : endpoints) {
auto iter = m_var_map.find(i);
if (iter != m_var_map.end()) {
auto iter = var_map().find(i);
if (iter != var_map().end()) {
i = iter->second;
}
}
......@@ -1357,7 +1115,7 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
ComputingGraphImpl* owner, Config* config_p)
: m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}
: SeqModifierBase(owner), m_config(config_p) {}
#endif // !MGB_ENABLE_SUBLINEAR
......
......@@ -12,6 +12,7 @@
#pragma once
#include "./memory_optimizer.h"
#include "./seq_modifier_base.h"
#include "megbrain/graph/cg.h"
#include "megbrain/utils/async_worker.h"
......@@ -23,28 +24,31 @@ namespace cg {
* \brief modifying computing sequence, with basically the same idea of Training
* Deep Nets with Sublinear Memory Cost
*/
class SeqModifierForSublinearMemory {
/*!
* describes modifications that should be applied to an operator sequnce:
* maps from an opr to the oprs that should be duplicated and inserted
* before it.
*/
using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
using SplitPointSet = std::shared_ptr<std::vector<size_t>>;
class SeqModifierForSublinearMemory : public SeqModifierBase {
//! Config options
using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
Config* m_config;
public:
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);
//! replace endpoint vars by the ones that require more computing
void modify_endpoint_vars(VarNodeArray& endpoints);
//! check whether actual opr_seq is what we expect; throw InternalError
void sanity_check(const OprNodeArray& opr_seq);
const CompNode::UnorderedMap<size_t>& prev_min_bottleneck();
private:
using SplitPointSet = std::shared_ptr<std::vector<size_t>>;
//! get modifications to be taken under some specific constraints
class ModifyActionPlanner;
//! search best modify action for opr seq on a single comp node
class ActionSearcherSingleCN;
struct Opr;
struct Var;
struct InternalDeleter {
void operator()(ActionSearcherSingleCN*) const;
void operator()(ModifyActionPlanner*) const;
......@@ -67,32 +71,8 @@ class SeqModifierForSublinearMemory {
//! thread pool to run ModifyActionPlanner
FutureThreadPool<void> m_planner_thread_pool;
//! map from original var to replaced var
ThinHashMap<VarNode*, VarNode*> m_var_map;
VarNodeArray m_new_inputs; //!< setup by replace_vars
MemoryOptimizerHelper m_mem_opt;
ComputingGraphImpl* const m_owner_graph = nullptr;
CompNode::UnorderedMap<size_t> m_prev_min_bottleneck;
/*!
* \brief replace input vars according to m_var_map, and store results in
* m_new_inputs;
* \return whether any var is changed
*/
bool replace_vars(const VarNodeArray& inputs);
/*!
* \brief copy opr and set inputs to m_new_inputs, and add outputs in
* m_var_map
* \return new operator
*/
OperatorNodeBase* copy_opr_from_new_inputs(OperatorNodeBase* opr,
bool recomp);
//! restore computing sequence and modify operator priority
void reset_opr_seq(const OprNodeArray& oprseq);
......@@ -107,27 +87,6 @@ class SeqModifierForSublinearMemory {
return std::make_shared<SplitPointSet::element_type>(
std::forward<Args>(args)...);
}
public:
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);
//! see memory_optimizer set_priority_before_opt
void set_priority_before_opt(const VarNodeArray& endpoints) {
m_mem_opt.set_priority_before_opt(endpoints);
}
//! see memory_optimizer restore_graph_option
void restore_graph_option() {
m_mem_opt.restore_graph_option();
}
//! replace endpoint vars by the ones that require more computing
void modify_endpoint_vars(VarNodeArray& endpoints);
//! check whether actual opr_seq is what we expect; throw InternalError
void sanity_check(const OprNodeArray& opr_seq);
const CompNode::UnorderedMap<size_t>& prev_min_bottleneck();
};
} // namespace cg
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册