Commit 355153e1 authored by Megvii Engine Team

feat(mge/dtr): add DTR in computing graph

GitOrigin-RevId: 8941810319ecc54e6751e0f584156d26d84ab1e2
Parent 76f4f975
......@@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import (
set_cpp_apply_const_with_tracing,
set_cpp_apply_with_tracing,
)
from .dtr_config import DTRConfig
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import (
apply_const_with_tracing,
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
class DTRConfig:
def __init__(
self, eviction_threshold: int = 0, evictee_minimum_size: int = 1 << 20
):
assert eviction_threshold > 0, "eviction_threshold must be greater than zero"
self.eviction_threshold = eviction_threshold
assert (
evictee_minimum_size >= 0
), "evictee_minimum_size must be greater or equal to zero"
self.evictee_minimum_size = evictee_minimum_size
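# Usage sketch (editorial annotation, not part of this patch): DTRConfig is consumed
# by megengine.jit.trace through the new `dtr_config` argument added below.
# `eviction_threshold` is the byte budget above which the planner starts evicting
# tensors; `evictee_minimum_size` is the smallest tensor (in bytes) that may be
# evicted. The traced function is a placeholder; any traceable step would do.
from megengine.jit import DTRConfig, trace

config = DTRConfig(
    eviction_threshold=4 * 1024 ** 3,  # start evicting above ~4 GiB of alive tensors
    evictee_minimum_size=1 << 20,      # tensors below 1 MiB are never evicted
)

@trace(symbolic=True, dtr_config=config)
def step(x, w):
    # stand-in computation; in practice this is a full training/inference step
    return (x @ w).sum()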
......@@ -37,6 +37,7 @@ from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils.naming import AutoNaming
from .dtr_config import DTRConfig
from .sublinear_memory_config import SublinearMemoryConfig
......@@ -142,6 +143,7 @@ class trace:
symbolic=False,
capture_as_const=False,
sublinear_memory_config: SublinearMemoryConfig = None,
dtr_config: DTRConfig = None,
profiling: bool = False,
opt_level: int = 2,
symbolic_shape: bool = True,
......@@ -150,6 +152,7 @@ class trace:
self._symbolic = symbolic
self._capture_as_const = capture_as_const
self._sublinear_memory_config = sublinear_memory_config
self._dtr_config = dtr_config
self._profiling = profiling
self._profiler = None
self._graph_opt_level = opt_level
......@@ -491,6 +494,15 @@ class trace:
graph.options.no_force_inplace = True
graph.options.seq_opt.enable_seq_comp_node_opt = False
graph.options.graph_opt_level = self._graph_opt_level
if self._dtr_config is not None:
graph.options.enable_dtr_memory_opt = True
graph.options.dtr_config.eviction_threshold = (
self._dtr_config.eviction_threshold
)
graph.options.dtr_config.evictee_minimum_size = (
self._dtr_config.evictee_minimum_size
)
# sublinear
if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True
......
......@@ -395,6 +395,7 @@ void init_graph_rt(py::module m) {
DEF_READWRITE(allocate_static_mem_after_graph_compile)
DEF_READWRITE(fake_next_exec)
DEF_READWRITE(enable_sublinear_memory_opt)
DEF_READWRITE(enable_dtr_memory_opt)
DEF_READWRITE(no_profiling_on_shape_change)
DEF_READWRITE(enable_var_mem_defragment)
DEF_READWRITE(enable_grad_var_static_reshape)
......@@ -402,6 +403,7 @@ void init_graph_rt(py::module m) {
DEF_READWRITE(comp_node_seq_record_level)
DEF_READWRITE(no_force_inplace)
DEF_READWRITE(sublinear_mem_config)
DEF_READWRITE(dtr_config)
// DEF_READWRITE(eager_evaluation)
// DEF_READWRITE(imperative_proxy_graph)
// DEF_READWRITE(extra_vardeps)
......@@ -434,6 +436,14 @@ void init_graph_rt(py::module m) {
DEF_READWRITE(lb_memory)
DEF_READWRITE(num_worker);
#undef CURRENT_CLASS
#define CURRENT_CLASS cg::ComputingGraph::Options::DTRConfig
py::class_<cg::ComputingGraph::Options::DTRConfig>(PyComputingGraphOptions, "DTRConfig")
DEF_READWRITE(eviction_threshold)
DEF_READWRITE(evictee_minimum_size);
#undef CURRENT_CLASS
auto common = rel_import("common", m, 1);
......
......@@ -250,6 +250,10 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
seq_modifier_for_sublinear_memory{owner,
&(owner->options().sublinear_mem_config)},
#endif
#if MGB_ENABLE_DTR
seq_modifier_for_dtr{owner,
&(owner->options().dtr_config)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
memory_swap_support{owner},
#endif
......@@ -473,6 +477,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
#if MGB_ENABLE_SUBLINEAR
if (options().enable_sublinear_memory_opt) {
mgb_assert(!options().enable_dtr_memory_opt);
if (!sopr_stat.has_virtual_grad) {
mgb_log_debug(
"no virtual grad var; sublinear memory may produce "
......@@ -485,6 +490,15 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
mgb_assert(!options().enable_sublinear_memory_opt);
#endif // MGB_ENABLE_SUBLINEAR
#if MGB_ENABLE_DTR
if (options().enable_dtr_memory_opt) {
mgb_assert(!options().enable_sublinear_memory_opt);
seq_modifier_for_dtr().set_priority_before_opt(dest_vars);
}
#else
mgb_assert(!options().enable_dtr_memory_opt);
#endif // MGB_ENABLE_DTR
#if !MGB_BUILD_SLIM_SERVING
mgb_assert(!options().eager_evaluation,
"attempt to compile eager_evaluation graph");
......@@ -558,7 +572,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
CompSeqExtraInfo extra_info;
cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
bool init_flag = false;
auto init_opr_seq = [&]() {
mgb_assert(!init_flag);
init_flag = true;
ThinHashMap<VarNode*, size_t> var2idx;
std::unordered_map<CallbackCallerKey, CallbackCallerVal,
CallbackCallerKey::Hash>
......@@ -629,6 +646,15 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
mgb_assert(!options().enable_memory_swap);
#endif
#if MGB_ENABLE_DTR
if (options().enable_dtr_memory_opt) {
MGB_TRY {
seq_modifier_for_dtr().modify_endpoint_vars(dest_vars);
init_opr_seq();
}
MGB_FINALLY(seq_modifier_for_dtr().restore_graph_option());
}
#endif
#if MGB_ENABLE_SUBLINEAR
if (options().enable_sublinear_memory_opt) {
MGB_TRY {
......@@ -650,12 +676,11 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
*/
seq_modifier_for_sublinear_memory().restore_graph_option());
seq_modifier_for_sublinear_memory().sanity_check(*opr_seq);
} else {
init_opr_seq();
}
#else
init_opr_seq();
#endif // MGB_ENABLE_SUBLINEAR
if (!init_flag) {
init_opr_seq();
}
return {std::move(extra_info), opr_seq, std::move(dest_vars)};
}
......@@ -751,6 +776,13 @@ ComputingGraphImpl::seq_modifier_for_sublinear_memory() {
}
#endif
#if MGB_ENABLE_DTR
SeqModifierForDTR&
ComputingGraphImpl::seq_modifier_for_dtr() {
return components().seq_modifier_for_dtr;
}
#endif
void ComputingGraphImpl::share_device_memory_with(ComputingGraph& other) {
mgb_assert(
!m_current_comp_seq,
......
......@@ -15,6 +15,7 @@
#include "./grad_manager.h"
#include "./graph_opt.h"
#include "./seq_comp_node_opt_impl.h"
#include "./seq_dtr.h"
#include "./seq_sublinear_memory.h"
#include "./static_infer_impl.h"
#include "./swap/memory_swap.h"
......@@ -80,6 +81,9 @@ class ComputingGraphImpl final : public ComputingGraph {
#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory seq_modifier_for_sublinear_memory;
#endif
#if MGB_ENABLE_DTR
SeqModifierForDTR seq_modifier_for_dtr;
#endif
#if MGB_ENABLE_MEMORY_SWAP
swap::MemorySwap memory_swap_support;
#endif
......@@ -218,6 +222,9 @@ public:
SeqModifierForSublinearMemory& seq_modifier_for_sublinear_memory();
#endif
#if MGB_ENABLE_DTR
SeqModifierForDTR& seq_modifier_for_dtr();
#endif
void share_device_memory_with(ComputingGraph& other) override;
void set_device_memory_allocator(
......
/**
* \file src/core/impl/graph/seq_dtr.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./seq_dtr.h"
#if MGB_ENABLE_DTR
using namespace mgb;
using namespace cg;
namespace {
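// Oprs carrying any of these flags are treated as non-recomputable by DTR:
// IMPURE_FUNC (the result may differ across executions), NO_AUTOMATIC_DUP (the
// opr cannot be shallow-copied automatically) and FORCE_UPDATE_INPUT_VAR (the
// opr updates an input in place). Vars produced by such oprs are never picked
// for eviction, and regenerate() refuses to copy them.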
bool is_bad_opr(OperatorNodeBase* opr) {
using F = OperatorNodeBase::NodeProp::Flag;
return opr->node_prop().contain(
F::IMPURE_FUNC | F::NO_AUTOMATIC_DUP | F::FORCE_UPDATE_INPUT_VAR);
}
} // namespace
class SeqModifierForDTR::ModifyActionPlanner : public ModifyActionPlannerBase {
public:
ModifyActionPlanner(SeqModifierBase* par) : ModifyActionPlannerBase{par} {}
void prepare(const OprNodeArray& opr_seq);
SeqModifyAction perform_dtr(CompNode comp_node, const OprNodeArray& seq, Config* config);
};
SeqModifierForDTR::SeqModifierForDTR(ComputingGraphImpl* owner, Config* config_g)
: SeqModifierBase(owner), m_config(config_g) {}
void SeqModifierForDTR::modify_endpoint_vars(VarNodeArray& endpoints) {
var_map().clear();
auto comp_seq = MemoryOptimizerHelper::CompSeq(owner_graph(), endpoints);
auto config =
MemoryOptimizerHelper::SubGraphConfig()
/*.add_bad_opr_flag(
OperatorNodeBase::NodeProp::Flag::IMPURE_FUNC)
.add_bad_opr_flag(
OperatorNodeBase::NodeProp::Flag::NO_AUTOMATIC_DUP)
.add_bad_opr_flag(OperatorNodeBase::NodeProp::Flag::
FORCE_UPDATE_INPUT_VAR)*/
// NOTE: it should not actually involve any opr with the above
// flags, but for better results some oprs (e.g. CudnnBatchNorm)
// should be involved, and they are guaranteed to NEVER be recomputed.
.add_bad_var_flag(VarNode::Flag::VOLATILE_CONTENT)
.add_bad_var_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC)
.add_bad_var_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.add_bad_var_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE);
auto cn2oprseq = mem_opt().split_into_cn2oprseq(*comp_seq.m_seq, config);
if (cn2oprseq->empty()) {
return;
}
SeqModifyAction action;
// stack-allocate the planner so it is released when this function returns
ModifyActionPlanner planner{this};
for (auto&& i : *cn2oprseq) {
auto&& cur = planner.perform_dtr(i.first, i.second, m_config);
action.insert(cur.begin(), cur.end());
}
apply_action(action, *comp_seq.m_seq);
for (auto&& i : endpoints) {
auto iter = var_map().find(i);
if (iter != var_map().end()) {
i = iter->second;
}
}
}
void SeqModifierForDTR::ModifyActionPlanner::prepare(const OprNodeArray& opr_seq) {
init_seq(opr_seq, false);
for (size_t i = 0; i < seq().size(); ++i) {
auto opr = seq()[i].get();
size_t est = 0;
for (auto i : opr->input) {
est += i->size;
}
for (auto i : opr->output) {
est += i->size;
}
opr->estimate_compute_time = static_cast<double>(est) / 1e8;
}
}
SeqModifierForDTR::SeqModifyAction SeqModifierForDTR::ModifyActionPlanner::perform_dtr(
CompNode comp_node, const OprNodeArray& opr_seq, Config* config) {
prepare(opr_seq);
SeqModifyAction action;
if (comp_node.locator().stream < 0) {
// do not modify system stream oprs
return action;
}
ThinHashSet<Var*> alive_vars;
size_t cur_usage = 0;
//! map from original var to latest var
ThinHashMap<VarNode*, Var*> latest_var;
ThinHashMap<VarNode*, size_t> pin;
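// `pin` counts how many pending uses a var has as an input of the opr being
// scheduled (or of a recomputation in flight); pinned vars are skipped by
// find_best() so an input is never evicted right before it is read.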
auto need_regen = [&](Var* var) {
return alive_vars.find(var) == alive_vars.end();
};
auto add_alive = [&](Var* var) {
auto&& ins = alive_vars.insert(var);
mgb_assert(ins.second);
cur_usage += var->size;
};
auto remove_alive = [&](Var* var) {
if (alive_vars.erase(var)) {
auto size = var->size;
mgb_assert(size <= cur_usage);
cur_usage -= size;
}
};
auto get_latest = [&](Var* var) {
auto iter = latest_var.find(var->orig_var);
if (iter == latest_var.end()) {
return var;
} else {
return iter->second;
}
};
double est_time = 0;
ThinHashMap<Var*, double> dfs_back;
ThinHashMap<Var*, double> dfs_front;
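// regen_time(var) estimates the penalty of evicting `var`:
//  - dfs_b: compute time needed to rebuild var now, recursively including any
//    of its inputs that are themselves evicted;
//  - dfs_f: pressure from the readers already recorded in var->access_rec,
//    accumulated over their currently evicted outputs.
// The product of the two is the recomputation-cost term used by find_best.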
auto regen_time = [&](Var* var) {
thin_function<double(Var*)> dfs_b;
thin_function<double(Var*)> dfs_f;
dfs_b = [&](Var* var) {
if (dfs_back.find(var) != dfs_back.end()) {
return dfs_back[var];
}
auto opr = var->owner_opr();
double sum_time = opr->estimate_compute_time;
for (auto i : opr->input) {
auto ivar = get_latest(i);
if (need_regen(ivar)) {
sum_time += dfs_b(ivar);
}
}
dfs_back[var] = sum_time;
return sum_time;
};
dfs_f = [&](Var* var) {
if (dfs_front.find(var) != dfs_front.end()) {
return dfs_front[var];
}
double sum_time = 1;
for (size_t j = 1; j < var->access_rec.size(); ++j) {
auto dep_opr = var->access_rec[j].opr;
for (auto o : dep_opr->output) {
o = get_latest(o);
if (need_regen(o)) {
sum_time += dfs_f(o);
}
}
}
dfs_front[var] = sum_time;
return sum_time;
};
return dfs_f(var) * dfs_b(var);
};
static constexpr double MAX_EVAL_VALUE = std::numeric_limits<double>::max();
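// Eviction heuristic (in the spirit of DTR): for each evictable alive var,
//   eval_value = regen_time(var) / var->size / (est_time - last_access_time + eps)
// The var with the smallest eval_value (cheap to recompute, large, and not
// accessed recently) is evicted first.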
auto find_best = [&]() {
Var* best = nullptr;
double min_eval_value = MAX_EVAL_VALUE;
dfs_back.clear();
dfs_front.clear();
for (auto var : alive_vars) {
if (var->size < config->evictee_minimum_size
|| pin[var->orig_var] > 0
|| is_bad_opr(var->owner_opr()->orig_opr)) {
continue;
}
double regen = regen_time(var);
double eval_value = regen / static_cast<double>(var->size)
/ (est_time - var->last_access_time + 1e-8);
if (eval_value < min_eval_value) {
min_eval_value = eval_value;
best = var;
}
}
return best;
};
auto do_evict = [&](Var* var) {
remove_alive(var);
};
auto auto_evict = [&](size_t needed) {
while (cur_usage + needed >= config->eviction_threshold) {
Var* v = find_best();
if (!v) {
break;
}
do_evict(v);
}
};
thin_function<Var*(Opr*, Var*)> regenerate;
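// regenerate(reader, var): materialize an evicted `var` again right before
// `reader` needs it. The producing opr is shallow-copied (evicted inputs are
// regenerated recursively first), auto_evict frees memory for the new outputs,
// the copy is queued in reader->oprs_insert_before so apply_action can splice
// it into the final sequence, and the fresh Var corresponding to `var` is
// returned.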
regenerate = [&](Opr* reader, Var* var) {
auto opr = var->owner_opr();
// FIXME: if var cannot be recomputed, the previous eviction may fail
if (is_bad_opr(opr->orig_opr)) {
return var;
}
auto new_opr_storage = opr_mempool().alloc_unique(opr->orig_opr, static_cast<size_t>(DUPOPR_TIME));
auto new_opr = new_opr_storage.get();
new_opr->input.reserve(opr->input.size());
new_opr->output.reserve(opr->output.size());
for (auto i : opr->input) {
i->last_access_time = est_time;
pin[i->orig_var] ++;
}
for (auto o : opr->output) {
auto lo = get_latest(o);
if (!need_regen(lo)) {
remove_alive(lo);
}
}
for (auto i : opr->input) {
auto ivar = get_latest(i);
if (need_regen(ivar)) {
ivar = regenerate(reader, ivar);
}
new_opr->input.push_back(ivar);
ivar->access_rec.emplace_back(new_opr);
}
reader->oprs_insert_before.emplace_back(std::move(new_opr_storage));
size_t needed = 0;
for (auto o : opr->output) {
needed += o->size;
}
auto_evict(needed);
Var* new_var = nullptr;
for (auto o : opr->output) {
auto lo = get_latest(o);
auto&& ovar = var_mempool().alloc_unique(lo->orig_var, lo->size,
new_opr);
ovar->recomp_id = lo->recomp_id + 1;
new_opr->output.push_back(ovar.get());
if (o == var) {
new_var = ovar.get();
}
add_alive(ovar.get());
ovar->last_access_time = est_time;
latest_var[o->orig_var] = ovar.get();
var_storage().emplace_back(std::move(ovar));
}
est_time += opr->estimate_compute_time;
for (auto i : opr->input) {
pin[i->orig_var] --;
}
return new_var;
};
for (size_t j = 0; j < seq().size(); ++j) {
auto opr = seq()[j].get();
for (auto i : opr->input) {
pin[i->orig_var] ++;
}
for (auto i : opr->input) {
i = get_latest(i);
if (need_regen(i)) {
i = regenerate(opr, i);
}
i->last_access_time = est_time;
}
size_t needed = 0;
for (auto o : opr->output) {
needed += o->size;
}
auto_evict(needed);
est_time += opr->estimate_compute_time;
for (auto o : opr->output) {
add_alive(o);
o->last_access_time = est_time;
}
for (auto i : opr->input) {
pin[i->orig_var] --;
}
for (auto i : opr->input) {
i = get_latest(i);
if (opr == i->last_access_opr())
remove_alive(i);
}
}
for (size_t j = 0; j < seq().size(); ++j) {
auto opr = seq()[j].get();
auto&& arr = opr->oprs_insert_before;
if (arr.empty()) {
continue;
}
auto&& dest = action[opr->orig_opr];
dest.reserve(arr.size());
for (auto&& i : arr) {
dest.push_back(i->orig_opr);
}
}
return action;
}
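// apply_action replays the original opr sequence and, for every reader with
// planned recomputations, shallow-copies the recorded producer oprs onto the
// (possibly replaced) input vars, then assigns strictly increasing execution
// priorities so each recomputed copy is scheduled immediately before its
// reader.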
void SeqModifierForDTR::apply_action(SeqModifyAction& action,
const OprNodeArray& oprseq) {
auto cur_priority = std::numeric_limits<decltype(
OperatorNodeBase::NodeProp::Attribute::priority)>::min();
ThinHashSet<OperatorNodeBase*> modified_opr;
ThinHashMap<OperatorNodeBase*, size_t> recomp_id;
auto set_priority = [&](OperatorNodeBase* opr) {
mgb_assert(modified_opr.insert(opr).second);
mem_opt().set_priority(opr, cur_priority++);
};
auto on_opr_visited = [&](OperatorNodeBase* opr) {
if (replace_vars(opr->input())) {
recomp_id[opr] ++;
opr = copy_opr_from_new_inputs(opr, true, recomp_id[opr] - 1);
}
set_priority(opr);
};
DepOprIter dep_iter{on_opr_visited};
for (auto opr : oprseq) {
auto iter = action.find(opr);
if (iter != action.end()) {
for (auto i : iter->second) {
replace_vars(i->input());
recomp_id[i] ++;
auto opr_new = copy_opr_from_new_inputs(i, false, recomp_id[i] - 1);
set_priority(opr_new);
}
action.erase(iter);
}
dep_iter.add(opr);
}
mgb_assert(action.empty());
}
#endif // MGB_ENABLE_DTR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/core/impl/graph/seq_dtr.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./memory_optimizer.h"
#include "./seq_modifier_base.h"
#include "megbrain/graph/cg.h"
#if MGB_ENABLE_DTR
namespace mgb {
namespace cg {
class SeqModifierForDTR : public SeqModifierBase {
//! Config options
using Config = mgb::cg::ComputingGraph::Options::DTRConfig;
Config* m_config;
class ModifyActionPlanner;
public:
SeqModifierForDTR(ComputingGraphImpl* owner, Config* config_g);
void modify_endpoint_vars(VarNodeArray& endpoints);
void apply_action(SeqModifyAction& action, const OprNodeArray& oprseq);
};
} // namespace cg
} // namespace mgb
#endif // MGB_ENABLE_DTR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -11,12 +11,12 @@
#include "./seq_modifier_base.h"
#if MGB_ENABLE_SUBLINEAR
#if MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR
using namespace mgb;
using namespace cg;
void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq) {
void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_seq, bool remove_unused_output) {
m_orig_opr_seq = &opr_seq;
m_var_storage.clear();
......@@ -76,15 +76,16 @@ void SeqModifierBase::ModifyActionPlannerBase::init_seq(const OprNodeArray& opr_
mgb_assert(!opr->output.empty());
}
// remove unused output
for (auto&& i : m_seq) {
auto&& oarr = i->output;
for (size_t j = 0; j < oarr.size();) {
if (oarr[j]->access_rec.size() == 1) {
std::swap(oarr[j], oarr.back());
oarr.pop_back();
} else
++j;
if (remove_unused_output) {
for (auto&& i : m_seq) {
auto&& oarr = i->output;
for (size_t j = 0; j < oarr.size();) {
if (oarr[j]->access_rec.size() == 1) {
std::swap(oarr[j], oarr.back());
oarr.pop_back();
} else
++j;
}
}
}
}
......@@ -105,17 +106,14 @@ bool SeqModifierBase::replace_vars(const VarNodeArray& inputs) {
OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs(
OperatorNodeBase* opr, bool recomp, size_t recomp_cnt) {
auto config = opr->config();
// update operator instance id to bypass the shallow copy's cache if
// it's a dup-opr-copying due to discarding.
// Don't update instance id by `this` pointer if it's a recomp-opr-copying
// because:
// 0) a recomp-opr would be copied iff its input vars are changed
// 1) some pairs of recomp-opr and dup-opr have the same inputs, params
// and config; we use the instance id to differentiate them.
// update operator instance id to bypass the shallow copy's cache, because
// some pairs of recomp-opr and dup-opr have the same inputs, params and
// config; we use the instance id to differentiate them. To be safe, we update
// the instance id whether the reason is `recomp` or `dup`.
config.name(opr->name() + (recomp ? ":recomp" : ":dup") + std::to_string(recomp_cnt));
config.update_instance_id(reinterpret_cast<void*>(
reinterpret_cast<size_t>(this) +
((static_cast<size_t>(recomp) + 1) << 10) * recomp_cnt));
(recomp_cnt << 1 | (recomp & 1))));
// Note: if all outputs of op were placed on the same comp_node, since its
// stream maybe changed during seq_comp_node_opt, output's comp_node has
......@@ -156,6 +154,6 @@ OperatorNodeBase* SeqModifierBase::copy_opr_from_new_inputs(
return opr_new;
}
#endif // MGB_ENABLE_SUBLINEAR
#endif // MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -17,12 +17,11 @@
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/system.h"
#include "megbrain/utils/async_worker.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/timer.h"
#if MGB_ENABLE_SUBLINEAR
#if MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR
namespace mgb {
namespace cg {
......@@ -57,11 +56,11 @@ public:
static constexpr size_t DUPOPR_TIME =
std::numeric_limits<size_t>::max() - 1;
const SeqModifierBase* const par_modifier() {
auto& par_modifier() {
return m_par_modifier;
}
const OprNodeArray* const orig_opr_seq() {
auto& orig_opr_seq() {
return m_orig_opr_seq;
}
......@@ -94,7 +93,7 @@ public:
}
//! init m_orig_opr_seq from opr_seq, should be called first.
void init_seq(const OprNodeArray& opr_seq);
void init_seq(const OprNodeArray& opr_seq, bool remove_unused_output=true);
};
SeqModifierBase(ComputingGraphImpl* owner) : m_mem_opt(owner), m_owner_graph(owner) {}
......@@ -103,7 +102,7 @@ public:
return m_mem_opt;
}
ComputingGraphImpl* const owner_graph() {
auto& owner_graph() {
return m_owner_graph;
}
......@@ -232,6 +231,6 @@ struct SeqModifierBase::Var {
} // namespace cg
} // namespace mgb
#endif // MGB_ENABLE_SUBLINEAR
#endif // MGB_ENABLE_SUBLINEAR || MGB_ENABLE_DTR
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
......@@ -18,6 +18,10 @@
#include <string>
#ifndef MGB_ENABLE_DTR
#define MGB_ENABLE_DTR ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD))
#endif // MGB_ENABLE_DTR
#ifndef MGB_ENABLE_SUBLINEAR
#define MGB_ENABLE_SUBLINEAR ((!MGB_BUILD_SLIM_SERVING) && (!!MGB_HAVE_THREAD))
#endif // MGB_ENABLE_SUBLINEAR
......
......@@ -433,6 +433,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
int num_worker = sys::get_cpu_count() / 2;
} sublinear_mem_config;
//! whether to enable DTR memory optimization
bool enable_dtr_memory_opt = false;
//! Control parameters for DTR memory optimization
struct DTRConfig {
//! memory budget in bytes; eviction starts once alive tensors exceed it
size_t eviction_threshold = 0;
//! tensors smaller than this size (in bytes) are never evicted
size_t evictee_minimum_size = 1ULL << 20;
} dtr_config;
//! do not re-profile to select best impl algo when input shape
//! changes (use previous algo)
bool no_profiling_on_shape_change = false;
......
......@@ -172,6 +172,15 @@ TEST(TestSublinearMemory, FullConv) {
}
}
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-3);
graph->options().enable_sublinear_memory_opt = false;
graph->options().enable_dtr_memory_opt = true;
graph->options().dtr_config.eviction_threshold = 1ULL << 30;
auto func = graph->compile(out_spec);
func->execute();
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-3);
}
......@@ -238,6 +247,15 @@ TEST(TestSublinearMemory, ConcatSplit) {
}
}
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-3);
graph->options().enable_sublinear_memory_opt = false;
graph->options().enable_dtr_memory_opt = true;
graph->options().dtr_config.eviction_threshold = 1ULL << 30;
auto func = graph->compile(out_spec);
func->execute();
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-3);
}
......@@ -302,6 +320,15 @@ TEST(TestSublinearMemory, MultiOutputOpr) {
}
}
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-3);
graph->options().enable_sublinear_memory_opt = false;
graph->options().enable_dtr_memory_opt = true;
graph->options().dtr_config.eviction_threshold = 1ULL << 30;
auto func = graph->compile(out_spec);
func->execute();
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-3);
}
......@@ -365,6 +392,15 @@ TEST(TestSublinearMemory, LongChain) {
}
}
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-4);
graph->options().enable_sublinear_memory_opt = false;
graph->options().enable_dtr_memory_opt = true;
graph->options().dtr_config.eviction_threshold = 1ULL << 30;
auto func = graph->compile(out_spec);
func->execute();
for (size_t i = 0; i < grad_params_get.size(); ++i)
MGB_ASSERT_TENSOR_NEAR(grad_params_get[i], grad_params_expect[i], 1e-4);
}
......