提交 65d554ed 编写于 作者: M Megvii Engine Team

refactor(dtr): improve dtr in computing graph

GitOrigin-RevId: d599a08e52c4feebc98b016dfd8073377c888efc
上级 4f28e146
...@@ -9,8 +9,32 @@ ...@@ -9,8 +9,32 @@
class DTRConfig: class DTRConfig:
r"""Configuration for DTR memory optimization.
Args:
eviction_threshold: eviction threshold in bytes. When GPU memory usage
exceeds this value, DTR will heuristically select and evict resident
tensors until the amount of used memory falls below this threshold.
evictee_minimum_size: memory threshold of tensors in bytes. Only tensors
whose size exceeds this threshold will be added to the candidate set.
A tensor that is not added to the candidate set will never be evicted
during its lifetime. Default: 1048576.
recomp_memory_factor: hyperparameter weighting the estimated memory needed
to recompute a tensor. The larger this value is, the more the heuristic
strategies favor evicting tensors whose recomputation consumes less
memory. This value must be greater than or equal to 0. Default: 1.
recomp_time_factor: hyperparameter weighting the estimated time needed
to recompute a tensor. The larger this value is, the more the heuristic
strategies favor evicting tensors whose recomputation takes less
time. This value must be greater than or equal to 0. Default: 1.
"""
def __init__( def __init__(
self, eviction_threshold: int = 0, evictee_minimum_size: int = 1 << 20 self,
eviction_threshold: int = 0,
evictee_minimum_size: int = 1 << 20,
recomp_memory_factor: float = 1,
recomp_time_factor: float = 1,
): ):
assert eviction_threshold > 0, "eviction_threshold must be greater to zero" assert eviction_threshold > 0, "eviction_threshold must be greater to zero"
self.eviction_threshold = eviction_threshold self.eviction_threshold = eviction_threshold
...@@ -18,3 +42,11 @@ class DTRConfig: ...@@ -18,3 +42,11 @@ class DTRConfig:
evictee_minimum_size >= 0 evictee_minimum_size >= 0
), "evictee_minimum_size must be greater or equal to zero" ), "evictee_minimum_size must be greater or equal to zero"
self.evictee_minimum_size = evictee_minimum_size self.evictee_minimum_size = evictee_minimum_size
assert (
recomp_memory_factor >= 0
), "recomp_memory_factor must be greater or equal to zero"
self.recomp_memory_factor = recomp_memory_factor
assert (
recomp_time_factor >= 0
), "recomp_time_factor must be greater or equal to zero"
self.recomp_time_factor = recomp_time_factor
...@@ -528,6 +528,12 @@ class trace: ...@@ -528,6 +528,12 @@ class trace:
graph.options.dtr_config.evictee_minimum_size = ( graph.options.dtr_config.evictee_minimum_size = (
self._dtr_config.evictee_minimum_size self._dtr_config.evictee_minimum_size
) )
graph.options.dtr_config.recomp_memory_factor = (
self._dtr_config.recomp_memory_factor
)
graph.options.dtr_config.recomp_time_factor = (
self._dtr_config.recomp_time_factor
)
# graph optimization # graph optimization
if self._graph_opt_config is not None: if self._graph_opt_config is not None:
mapping = {None: 0, False: 1, True: 2} mapping = {None: 0, False: 1, True: 2}
......
...@@ -485,7 +485,9 @@ void init_graph_rt(py::module m) { ...@@ -485,7 +485,9 @@ void init_graph_rt(py::module m) {
py::class_<cg::ComputingGraph::Options::DTRConfig>(PyComputingGraphOptions, "DTRConfig") py::class_<cg::ComputingGraph::Options::DTRConfig>(PyComputingGraphOptions, "DTRConfig")
DEF_READWRITE(eviction_threshold) DEF_READWRITE(eviction_threshold)
DEF_READWRITE(evictee_minimum_size); DEF_READWRITE(evictee_minimum_size)
DEF_READWRITE(recomp_memory_factor)
DEF_READWRITE(recomp_time_factor);
#undef CURRENT_CLASS #undef CURRENT_CLASS
auto common = rel_import("common", m, 1); auto common = rel_import("common", m, 1);
......
...@@ -105,6 +105,7 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -105,6 +105,7 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
ThinHashSet<Var*> alive_vars; ThinHashSet<Var*> alive_vars;
size_t cur_usage = 0; size_t cur_usage = 0;
size_t cur_op_cnt = 0;
//! map from original var to latest var //! map from original var to latest var
ThinHashMap<VarNode*, Var*> latest_var; ThinHashMap<VarNode*, Var*> latest_var;
...@@ -125,7 +126,9 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -125,7 +126,9 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
auto size = var->size; auto size = var->size;
mgb_assert(size <= cur_usage); mgb_assert(size <= cur_usage);
cur_usage -= size; cur_usage -= size;
return true;
} }
return false;
}; };
auto get_latest = [&](Var* var) { auto get_latest = [&](Var* var) {
...@@ -137,11 +140,10 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -137,11 +140,10 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
} }
}; };
double est_time = 0;
ThinHashMap<Var*, double> dfs_back; ThinHashMap<Var*, double> dfs_back;
ThinHashMap<Var*, double> dfs_ops;
ThinHashMap<Var*, double> dfs_front; ThinHashMap<Var*, double> dfs_front;
ThinHashMap<Var*, double> dfs_mem;
auto regen_time = [&](Var* var) { auto regen_time = [&](Var* var) {
thin_function<double(Var*)> dfs_b; thin_function<double(Var*)> dfs_b;
thin_function<double(Var*)> dfs_f; thin_function<double(Var*)> dfs_f;
...@@ -180,21 +182,67 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -180,21 +182,67 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
return dfs_f(var) * dfs_b(var); return dfs_f(var) * dfs_b(var);
}; };
//! Estimate the memory involved in recomputing \p var: the size of the var
//! itself plus, recursively, the estimate of every input of its owner
//! operator (resolved through get_latest) for which need_regen() holds.
//! Per-var results are memoized in dfs_mem.
auto regen_mem = [&](Var* var) {
thin_function<double(Var*)> dfs_b;
dfs_b = [&](Var* var) {
// reuse the cached estimate if this var has already been visited
if (dfs_mem.find(var) != dfs_mem.end()) {
return dfs_mem[var];
}
auto opr = var->owner_opr();
// start from the var's own size, then add the cost of its inputs
double mem_sum = var->size;
for (auto i : opr->input) {
auto ivar = get_latest(i);
// only inputs that need_regen() reports must be accounted for
if (need_regen(ivar)) {
mem_sum += dfs_b(ivar);
}
}
dfs_mem[var] = mem_sum;
return mem_sum;
};
return dfs_b(var);
};
//! Distance from the current operator to the next recorded access of the
//! latest version of \p var. Returns (t + 1 - cur_op_cnt), where t is the
//! smallest access time that is >= cur_op_cnt and < DUPOPR_TIME; returns
//! DUPOPR_TIME itself when no such access record exists.
auto next_used = [&](Var* var) {
var = get_latest(var);
// DUPOPR_TIME doubles as the "no upcoming access found" sentinel
size_t t = DUPOPR_TIME;
for (auto rec : var->access_rec) {
// NOTE(review): cur_op_cnt - 1 is unsigned arithmetic; this presumably
// assumes cur_op_cnt >= 1 whenever next_used is called -- confirm
if (rec.time > cur_op_cnt - 1 && rec.time < t)
t = rec.time;
}
if (t < DUPOPR_TIME) {
return t + 1 - cur_op_cnt;
} else {
return t;
}
};
double tim_factor = 1;
double mem_factor = 1;
if (config->recomp_memory_factor >= 0) {
mem_factor = config->recomp_memory_factor;
}
if (config->recomp_time_factor >= 0) {
tim_factor = config->recomp_time_factor;
}
static constexpr double MAX_EVAL_VALUE = std::numeric_limits<double>::max(); static constexpr double MAX_EVAL_VALUE = std::numeric_limits<double>::max();
auto find_best = [&]() { auto find_best = [&]() {
Var* best = nullptr; Var* best = nullptr;
double min_eval_value = MAX_EVAL_VALUE; double min_eval_value = MAX_EVAL_VALUE;
dfs_back.clear(); dfs_back.clear();
dfs_front.clear(); dfs_front.clear();
dfs_mem.clear();
for (auto var : alive_vars) { for (auto var : alive_vars) {
if (var->size < config->evictee_minimum_size if (var->size < config->evictee_minimum_size
|| pin[var->orig_var] > 0 || pin[var->orig_var] > 0
|| is_bad_opr(var->owner_opr()->orig_opr)) { || is_bad_opr(var->owner_opr()->orig_opr)) {
continue; continue;
} }
double regen = regen_time(var); double regen_t = regen_time(var);
double eval_value = regen / static_cast<double>(var->size) double regen_m = regen_mem(var);
/ (est_time - var->last_access_time + 1e-8); double eval_value = pow(regen_t, tim_factor) * pow(regen_m, mem_factor)
/ static_cast<double>(var->size)
/ next_used(var);
if (eval_value < min_eval_value) { if (eval_value < min_eval_value) {
min_eval_value = eval_value; min_eval_value = eval_value;
best = var; best = var;
...@@ -207,7 +255,18 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -207,7 +255,18 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
remove_alive(var); remove_alive(var);
}; };
thin_function<void(Var*)> recursive_free;
auto auto_evict = [&](size_t needed) { auto auto_evict = [&](size_t needed) {
// proactively remove end-of-life vars
std::vector<Var*> to_free(0);
for (auto i : alive_vars) {
if (next_used(get_latest(i)) == DUPOPR_TIME && pin[i->orig_var]==0) {
to_free.push_back(get_latest(i));
}
}
for (auto i : to_free) {
recursive_free(get_latest(i));
}
while (cur_usage + needed >= config->eviction_threshold) { while (cur_usage + needed >= config->eviction_threshold) {
Var* v = find_best(); Var* v = find_best();
if (!v) { if (!v) {
...@@ -217,6 +276,25 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -217,6 +276,25 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
} }
}; };
//! Release \p var if it is no longer needed, then try to release the inputs
//! of its owner operator in the same way. A var is left untouched when it is
//! pinned (pin count > 0) or when any of its access records has a time
//! >= cur_op_cnt.
recursive_free = [&](Var* var) {
// pinned vars must stay alive
if (pin[var->orig_var] > 0) return;
auto opr = var->owner_opr();
// check whether any access at or after the current operator remains
bool need = false;
for (auto i : var->access_rec) {
if (i.time >= cur_op_cnt) {
need = true;
break;
}
}
if (!need) {
// recurse into the inputs only when remove_alive() returns true
if (remove_alive(var)) {
for (auto i : opr->input) {
recursive_free(get_latest(i));
}
}
}
};
thin_function<Var*(Opr*, Var*)> regenerate; thin_function<Var*(Opr*, Var*)> regenerate;
regenerate = [&](Opr* reader, Var* var) { regenerate = [&](Opr* reader, Var* var) {
auto opr = var->owner_opr(); auto opr = var->owner_opr();
...@@ -230,17 +308,11 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -230,17 +308,11 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
new_opr->input.reserve(opr->input.size()); new_opr->input.reserve(opr->input.size());
new_opr->output.reserve(opr->output.size()); new_opr->output.reserve(opr->output.size());
new_opr->estimate_compute_time = opr->estimate_compute_time;
for (auto i : opr->input) { for (auto i : opr->input) {
i->last_access_time = est_time;
pin[i->orig_var] ++; pin[i->orig_var] ++;
} }
for (auto o : opr->output) {
auto lo = get_latest(o);
if (!need_regen(lo)) {
remove_alive(lo);
}
}
for (auto i : opr->input) { for (auto i : opr->input) {
auto ivar = get_latest(i); auto ivar = get_latest(i);
if (need_regen(ivar)) { if (need_regen(ivar)) {
...@@ -264,15 +336,18 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -264,15 +336,18 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
new_opr); new_opr);
ovar->recomp_id = lo->recomp_id + 1; ovar->recomp_id = lo->recomp_id + 1;
new_opr->output.push_back(ovar.get()); new_opr->output.push_back(ovar.get());
if (o == var) { if (need_regen(lo)) { // latest output is not in memory
new_var = ovar.get(); if (o == var) {
new_var = ovar.get();
for (size_t i = 1; i < lo->access_rec.size(); i ++) {
new_var->access_rec.push_back(lo->access_rec[i]);
}
add_alive(new_var);
latest_var[o->orig_var] = new_var;
}
} }
add_alive(ovar.get());
ovar->last_access_time = est_time;
latest_var[o->orig_var] = ovar.get();
var_storage().emplace_back(std::move(ovar)); var_storage().emplace_back(std::move(ovar));
} }
est_time += opr->estimate_compute_time;
for (auto i : opr->input) { for (auto i : opr->input) {
pin[i->orig_var] --; pin[i->orig_var] --;
} }
...@@ -280,34 +355,44 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo ...@@ -280,34 +355,44 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
}; };
for (size_t j = 0; j < seq().size(); ++j) { for (size_t j = 0; j < seq().size(); ++j) {
++ cur_op_cnt;
auto opr = seq()[j].get(); auto opr = seq()[j].get();
for (auto i : opr->input) { for (auto i : opr->input) {
pin[i->orig_var] ++; pin[i->orig_var] ++;
} }
for (auto i : opr->inputs_size) {
if (i > 0) cur_usage += i;
}
for (auto i : opr->input) { for (auto i : opr->input) {
i = get_latest(i); i = get_latest(i);
if (need_regen(i)) { if (need_regen(i)) {
i = regenerate(opr, i); i = regenerate(opr, i);
} }
i->last_access_time = est_time;
} }
size_t needed = 0; size_t needed = 0;
for (auto o : opr->output) { for (auto o : opr->output) {
needed += o->size; needed += o->size;
} }
auto_evict(needed); auto_evict(needed);
est_time += opr->estimate_compute_time;
for (auto o : opr->output) { for (auto o : opr->output) {
o = get_latest(o);
add_alive(o); add_alive(o);
o->last_access_time = est_time;
} }
for (auto i : opr->input) { for (auto i : opr->input) {
pin[i->orig_var] --; pin[i->orig_var] --;
} }
for (auto i : opr->input) { for (auto i : opr->input) {
i = get_latest(i); if (opr == i->last_access_opr()) {
if (opr == i->last_access_opr()) recursive_free(get_latest(i));
remove_alive(i); }
}
for (auto o : opr->output) {
if (opr == o->last_access_opr()) {
recursive_free(get_latest(o));
}
}
for (auto i : opr->inputs_size) {
if (i < 0) cur_usage += i;
} }
} }
for (size_t j = 0; j < seq().size(); ++j) { for (size_t j = 0; j < seq().size(); ++j) {
......
...@@ -26,6 +26,7 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_ ...@@ -26,6 +26,7 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
m_nr_endpoint_oprs = 0; m_nr_endpoint_oprs = 0;
ThinHashMap<VarNode*, Var*> varmap; ThinHashMap<VarNode*, Var*> varmap;
ThinHashMap<VarNode*, Opr*> var_used;
for (auto orig_opr : *m_orig_opr_seq) { for (auto orig_opr : *m_orig_opr_seq) {
auto time = m_seq.size(); auto time = m_seq.size();
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time)); m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time));
...@@ -38,6 +39,11 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_ ...@@ -38,6 +39,11 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
auto iter = varmap.find(dep.first); auto iter = varmap.find(dep.first);
if (iter == varmap.end()) { if (iter == varmap.end()) {
// input var needs not to be considered // input var needs not to be considered
size_t size = dep.first->dtype().size(dep.first->shape().total_nr_elems());
if (!var_used[dep.first]) {
opr->inputs_size.push_back(size);
}
var_used[dep.first] = opr;
continue; continue;
} }
...@@ -75,7 +81,10 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_ ...@@ -75,7 +81,10 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
} }
mgb_assert(!opr->output.empty()); mgb_assert(!opr->output.empty());
} }
for (auto x : var_used) {
size_t size = x.first->dtype().size(x.first->shape().total_nr_elems());
var_used[x.first]->inputs_size.push_back(-static_cast<ptrdiff_t>(size));
}
if (remove_unused_output) { if (remove_unused_output) {
for (auto&& i : m_seq) { for (auto&& i : m_seq) {
auto&& oarr = i->output; auto&& oarr = i->output;
...@@ -156,4 +165,4 @@ OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs( ...@@ -156,4 +165,4 @@ OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs(
#endif // MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR #endif // MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
...@@ -163,6 +163,8 @@ struct SeqModifierBase::Opr { ...@@ -163,6 +163,8 @@ struct SeqModifierBase::Opr {
//! by make_discard_plan() //! by make_discard_plan()
size_t block_begin_time = 0, block_end_time = 0; size_t block_begin_time = 0, block_end_time = 0;
std::vector<ptrdiff_t> inputs_size;
Opr(OperatorNodeBase* opr, size_t t) Opr(OperatorNodeBase* opr, size_t t)
: orig_opr{opr}, : orig_opr{opr},
time{t}, time{t},
...@@ -176,7 +178,6 @@ struct SeqModifierBase::Var { ...@@ -176,7 +178,6 @@ struct SeqModifierBase::Var {
VarNode* const orig_var; VarNode* const orig_var;
size_t size; //!< memory usage in bytes of this var size_t size; //!< memory usage in bytes of this var
size_t recomp_id = 0; size_t recomp_id = 0;
double last_access_time = 0;
//! write or read access of a var //! write or read access of a var
struct AccessRecord { struct AccessRecord {
...@@ -233,4 +234,4 @@ struct SeqModifierBase::Var { ...@@ -233,4 +234,4 @@ struct SeqModifierBase::Var {
#endif // MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR #endif // MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
...@@ -454,6 +454,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, ...@@ -454,6 +454,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
struct DTRConfig { struct DTRConfig {
size_t eviction_threshold = 0; size_t eviction_threshold = 0;
size_t evictee_minimum_size = 1ULL << 20; size_t evictee_minimum_size = 1ULL << 20;
double recomp_memory_factor = 1;
double recomp_time_factor = 1;
} dtr_config; } dtr_config;
//! do not re-profile to select best impl algo when input shape //! do not re-profile to select best impl algo when input shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册