提交 65d554ed 编写于 作者: M Megvii Engine Team

refactor(dtr): improve dtr in computing graph

GitOrigin-RevId: d599a08e52c4feebc98b016dfd8073377c888efc
上级 4f28e146
......@@ -9,8 +9,32 @@
class DTRConfig:
r"""Configuration for DTR memory optimization.
Args:
eviction_threshold: eviction threshold in bytes. When GPU memory usage
exceeds this value, DTR will heuristically select and evict resident
tensors until the amount of used memory falls below this threshold.
evictee_minimum_size: memory threshold of tensors in bytes. Only tensors
whose size exceeds this threshold will be added to the candidate set.
A tensor that is not added to the candidate set will never be evicted
during its lifetime. Default: 1048576.
recomp_memory_factor: hyperparameter applied to the estimated memory of recomputing
a tensor. The larger this value is, the more the heuristic strategies
favor evicting tensors whose recomputation consumes less memory. This value
must be greater than or equal to 0. Default: 1.
recomp_time_factor: hyperparameter applied to the estimated time of recomputing
a tensor. The larger this value is, the more the heuristic strategies
favor evicting tensors whose recomputation takes less time. This value
must be greater than or equal to 0. Default: 1.
"""
def __init__(
self, eviction_threshold: int = 0, evictee_minimum_size: int = 1 << 20
self,
eviction_threshold: int = 0,
evictee_minimum_size: int = 1 << 20,
recomp_memory_factor: float = 1,
recomp_time_factor: float = 1,
):
assert eviction_threshold > 0, "eviction_threshold must be greater to zero"
self.eviction_threshold = eviction_threshold
......@@ -18,3 +42,11 @@ class DTRConfig:
evictee_minimum_size >= 0
), "evictee_minimum_size must be greater or equal to zero"
self.evictee_minimum_size = evictee_minimum_size
assert (
recomp_memory_factor >= 0
), "recomp_memory_factor must be greater or equal to zero"
self.recomp_memory_factor = recomp_memory_factor
assert (
recomp_time_factor >= 0
), "recomp_time_factor must be greater or equal to zero"
self.recomp_time_factor = recomp_time_factor
......@@ -528,6 +528,12 @@ class trace:
graph.options.dtr_config.evictee_minimum_size = (
self._dtr_config.evictee_minimum_size
)
graph.options.dtr_config.recomp_memory_factor = (
self._dtr_config.recomp_memory_factor
)
graph.options.dtr_config.recomp_time_factor = (
self._dtr_config.recomp_time_factor
)
# graph optimization
if self._graph_opt_config is not None:
mapping = {None: 0, False: 1, True: 2}
......
......@@ -485,7 +485,9 @@ void init_graph_rt(py::module m) {
py::class_<cg::ComputingGraph::Options::DTRConfig>(PyComputingGraphOptions, "DTRConfig")
DEF_READWRITE(eviction_threshold)
DEF_READWRITE(evictee_minimum_size);
DEF_READWRITE(evictee_minimum_size)
DEF_READWRITE(recomp_memory_factor)
DEF_READWRITE(recomp_time_factor);
#undef CURRENT_CLASS
auto common = rel_import("common", m, 1);
......
......@@ -105,6 +105,7 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
ThinHashSet<Var*> alive_vars;
size_t cur_usage = 0;
size_t cur_op_cnt = 0;
//! map from original var to latest var
ThinHashMap<VarNode*, Var*> latest_var;
......@@ -125,7 +126,9 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
auto size = var->size;
mgb_assert(size <= cur_usage);
cur_usage -= size;
return true;
}
return false;
};
auto get_latest = [&](Var* var) {
......@@ -137,11 +140,10 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
}
};
double est_time = 0;
ThinHashMap<Var*, double> dfs_back;
ThinHashMap<Var*, double> dfs_ops;
ThinHashMap<Var*, double> dfs_front;
ThinHashMap<Var*, double> dfs_mem;
auto regen_time = [&](Var* var) {
thin_function<double(Var*)> dfs_b;
thin_function<double(Var*)> dfs_f;
......@@ -180,21 +182,67 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
return dfs_f(var) * dfs_b(var);
};
auto regen_mem = [&](Var* var) {
thin_function<double(Var*)> dfs_b;
dfs_b = [&](Var* var) {
if (dfs_mem.find(var) != dfs_mem.end()) {
return dfs_mem[var];
}
auto opr = var->owner_opr();
double mem_sum = var->size;
for (auto i : opr->input) {
auto ivar = get_latest(i);
if (need_regen(ivar)) {
mem_sum += dfs_b(ivar);
}
}
dfs_mem[var] = mem_sum;
return mem_sum;
};
return dfs_b(var);
};
auto next_used = [&](Var* var) {
var = get_latest(var);
size_t t = DUPOPR_TIME;
for (auto rec : var->access_rec) {
if (rec.time > cur_op_cnt - 1 && rec.time < t)
t = rec.time;
}
if (t < DUPOPR_TIME) {
return t + 1 - cur_op_cnt;
} else {
return t;
}
};
double tim_factor = 1;
double mem_factor = 1;
if (config->recomp_memory_factor >= 0) {
mem_factor = config->recomp_memory_factor;
}
if (config->recomp_time_factor >= 0) {
tim_factor = config->recomp_time_factor;
}
static constexpr double MAX_EVAL_VALUE = std::numeric_limits<double>::max();
auto find_best = [&]() {
Var* best = nullptr;
double min_eval_value = MAX_EVAL_VALUE;
dfs_back.clear();
dfs_front.clear();
dfs_mem.clear();
for (auto var : alive_vars) {
if (var->size < config->evictee_minimum_size
|| pin[var->orig_var] > 0
|| is_bad_opr(var->owner_opr()->orig_opr)) {
continue;
}
double regen = regen_time(var);
double eval_value = regen / static_cast<double>(var->size)
/ (est_time - var->last_access_time + 1e-8);
double regen_t = regen_time(var);
double regen_m = regen_mem(var);
double eval_value = pow(regen_t, tim_factor) * pow(regen_m, mem_factor)
/ static_cast<double>(var->size)
/ next_used(var);
if (eval_value < min_eval_value) {
min_eval_value = eval_value;
best = var;
......@@ -207,7 +255,18 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
remove_alive(var);
};
thin_function<void(Var*)> recursive_free;
auto auto_evict = [&](size_t needed) {
// proactively remove end-of-life vars
std::vector<Var*> to_free(0);
for (auto i : alive_vars) {
if (next_used(get_latest(i)) == DUPOPR_TIME && pin[i->orig_var]==0) {
to_free.push_back(get_latest(i));
}
}
for (auto i : to_free) {
recursive_free(get_latest(i));
}
while (cur_usage + needed >= config->eviction_threshold) {
Var* v = find_best();
if (!v) {
......@@ -217,6 +276,25 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
}
};
//! Free \p var when it is not pinned and none of its access records has a
//! time >= cur_op_cnt; if remove_alive() reports that the var was removed,
//! apply the same procedure to the latest version of every input of its
//! owner opr.
recursive_free = [&](Var* var) {
    if (pin[var->orig_var] > 0) {
        return;
    }
    auto producer = var->owner_opr();
    for (auto rec : var->access_rec) {
        if (rec.time >= cur_op_cnt) {
            // the var is still accessed at or after the current position
            return;
        }
    }
    if (!remove_alive(var)) {
        return;
    }
    for (auto src : producer->input) {
        recursive_free(get_latest(src));
    }
};
thin_function<Var*(Opr*, Var*)> regenerate;
regenerate = [&](Opr* reader, Var* var) {
auto opr = var->owner_opr();
......@@ -230,17 +308,11 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
new_opr->input.reserve(opr->input.size());
new_opr->output.reserve(opr->output.size());
new_opr->estimate_compute_time = opr->estimate_compute_time;
for (auto i : opr->input) {
i->last_access_time = est_time;
pin[i->orig_var] ++;
}
for (auto o : opr->output) {
auto lo = get_latest(o);
if (!need_regen(lo)) {
remove_alive(lo);
}
}
for (auto i : opr->input) {
auto ivar = get_latest(i);
if (need_regen(ivar)) {
......@@ -264,15 +336,18 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
new_opr);
ovar->recomp_id = lo->recomp_id + 1;
new_opr->output.push_back(ovar.get());
if (need_regen(lo)) { // latest output is not in memory
if (o == var) {
new_var = ovar.get();
for (size_t i = 1; i < lo->access_rec.size(); i ++) {
new_var->access_rec.push_back(lo->access_rec[i]);
}
add_alive(new_var);
latest_var[o->orig_var] = new_var;
}
}
add_alive(ovar.get());
ovar->last_access_time = est_time;
latest_var[o->orig_var] = ovar.get();
var_storage().emplace_back(std::move(ovar));
}
est_time += opr->estimate_compute_time;
for (auto i : opr->input) {
pin[i->orig_var] --;
}
......@@ -280,34 +355,44 @@ SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perfo
};
for (size_t j = 0; j < seq().size(); ++j) {
++ cur_op_cnt;
auto opr = seq()[j].get();
for (auto i : opr->input) {
pin[i->orig_var] ++;
}
for (auto i : opr->inputs_size) {
if (i > 0) cur_usage += i;
}
for (auto i : opr->input) {
i = get_latest(i);
if (need_regen(i)) {
i = regenerate(opr, i);
}
i->last_access_time = est_time;
}
size_t needed = 0;
for (auto o : opr->output) {
needed += o->size;
}
auto_evict(needed);
est_time += opr->estimate_compute_time;
for (auto o : opr->output) {
o = get_latest(o);
add_alive(o);
o->last_access_time = est_time;
}
for (auto i : opr->input) {
pin[i->orig_var] --;
}
for (auto i : opr->input) {
i = get_latest(i);
if (opr == i->last_access_opr())
remove_alive(i);
if (opr == i->last_access_opr()) {
recursive_free(get_latest(i));
}
}
for (auto o : opr->output) {
if (opr == o->last_access_opr()) {
recursive_free(get_latest(o));
}
}
for (auto i : opr->inputs_size) {
if (i < 0) cur_usage += i;
}
}
for (size_t j = 0; j < seq().size(); ++j) {
......
......@@ -26,6 +26,7 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
m_nr_endpoint_oprs = 0;
ThinHashMap<VarNode*, Var*> varmap;
ThinHashMap<VarNode*, Opr*> var_used;
for (auto orig_opr : *m_orig_opr_seq) {
auto time = m_seq.size();
m_seq.emplace_back(m_opr_mempool.alloc_unique(orig_opr, time));
......@@ -38,6 +39,11 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
auto iter = varmap.find(dep.first);
if (iter == varmap.end()) {
// input var needs not to be considered
size_t size = dep.first->dtype().size(dep.first->shape().total_nr_elems());
if (!var_used[dep.first]) {
opr->inputs_size.push_back(size);
}
var_used[dep.first] = opr;
continue;
}
......@@ -75,7 +81,10 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
}
mgb_assert(!opr->output.empty());
}
for (auto x : var_used) {
size_t size = x.first->dtype().size(x.first->shape().total_nr_elems());
var_used[x.first]->inputs_size.push_back(-static_cast<ptrdiff_t>(size));
}
if (remove_unused_output) {
for (auto&& i : m_seq) {
auto&& oarr = i->output;
......
......@@ -163,6 +163,8 @@ struct SeqModifierBase::Opr {
//! by make_discard_plan()
size_t block_begin_time = 0, block_end_time = 0;
std::vector<ptrdiff_t> inputs_size;
Opr(OperatorNodeBase* opr, size_t t)
: orig_opr{opr},
time{t},
......@@ -176,7 +178,6 @@ struct SeqModifierBase::Var {
VarNode* const orig_var;
size_t size; //!< memory usage in bytes of this var
size_t recomp_id = 0;
double last_access_time = 0;
//! write or read access of a var
struct AccessRecord {
......
......@@ -454,6 +454,8 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
//! configuration of the DTR memory optimization
struct DTRConfig {
    //! memory threshold in bytes; vars are evicted while the current usage
    //! plus the memory needed by the next opr reaches this value
    size_t eviction_threshold{0};
    //! vars smaller than this many bytes are never selected for eviction
    size_t evictee_minimum_size{static_cast<size_t>(1) << 20};
    //! exponent applied to the estimated recomputation memory in the
    //! eviction score
    double recomp_memory_factor{1.0};
    //! exponent applied to the estimated recomputation time in the
    //! eviction score
    double recomp_time_factor{1.0};
} dtr_config;
//! do not re-profile to select best impl algo when input shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册