Commit 9415ba58 authored by Megvii Engine Team

feat(src/core): free weight preprocessed weight

GitOrigin-RevId: 5f91acb909bdc58bfa8494c53665eaf7dfed15da
Parent 7cd71c31
......@@ -576,10 +576,10 @@ VarNode& VarNode::add_flag(Flag flag) {
void VarNode::modify_flag(Flag delta, Flag new_flag) {
if (contain_flag(Flag::FLAG_FREEZED)) {
mgb_assert((delta & (
Flag::NO_MEM_RECLAIM |
Flag::NO_SYS_STATIC_MEM_ALLOC |
Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta);
mgb_assert(
(delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
(new_flag & Flag::MEMORY_NO_NEED));
mgb_assert(!ComputingGraphImpl::downcast(owner_graph())->
var_node_mem_manager().optimize_started(),
......
......@@ -24,6 +24,8 @@
#include "megbrain/utils/timer.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/opr/io.h"
#include <chrono>
using namespace mgb;
......@@ -36,7 +38,6 @@ void call_mem_status_changed(cg::OperatorNodeBase* opr) {
if (cb.on_mem_status_changed.valid())
cb.on_mem_status_changed.val()();
}
} // namespace
/* ==================== StaticDeviceMemoryManager ==================== */
......@@ -393,11 +394,12 @@ bool VarNodeMemManager::alloc_var_node_mem_static() {
bool VarNodeMemManager::update_static_alloc_plan() {
// check whether unchanged
bool free_no_need_memory = free_combine_memory_no_need_var();
if (!m_owner_graph->static_infer_comp_seq_manager()
.update_static_check_shape_change() &&
!m_first_static_plan_run &&
!m_impure_mem_plan_mgr.check_need_realloc()) {
return false;
return false || free_no_need_memory;
}
if (m_first_static_plan_run)
......@@ -494,6 +496,96 @@ bool VarNodeMemManager::make_static_var_tensor_from_alloc_plan() {
return true;
}
bool VarNodeMemManager::free_combine_memory_no_need_var() {
if (!m_owner_graph->options().graph_opt.weight_preprocess ||
m_already_free_no_need_mem) {
return false;
}
bool reordered = false;
//! free storage that is no longer needed
for (auto opr : *m_opr_seq) {
if (opr->try_cast_final<opr::SharedDeviceTensor>() ||
opr->try_cast_final<opr::SharedDeviceTensorWithFormat>()) {
auto opr_base = static_cast<opr::intl::SharedDeviceTensorBase*>(opr);
auto var = opr_base->output(0);
if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
var->dev_tensor_valid() && !var->dev_tensor().empty()) {
//! Only when the tensor share count is 1 can it be freed
if (opr_base->dev_data().use_count() == 1) {
auto layout = var->layout();
var->m_dev_tensor.reset(
DeviceTensorStorage{var->comp_node()}, layout);
opr_base->free_dev_data();
mgb_log_debug(
"preprocessed weight is freed, var name = %s, "
"var layout = %s",
var->name().c_str(), layout.to_string().c_str());
}
m_already_free_no_need_mem = true;
}
}
bool memory_need_reorder = false;
if (opr->try_cast_final<opr::MultipleDeviceTensorHolder>() ||
opr->try_cast_final<opr::MultipleDeviceTensorWithFormatHolder>()) {
auto opr_base =
static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
opr);
for (size_t index = 0; index < opr_base->output().size(); index++) {
auto var = opr_base->output(index);
if (var->contain_flag(VarNode::Flag::MEMORY_NO_NEED) &&
var->dev_tensor_valid() && !var->dev_tensor().empty()) {
//! Only when the tensor share count is 1 can it be freed
if (opr_base->values()[index].use_count() == 1) {
auto layout = var->layout();
var->m_dev_tensor.reset(
DeviceTensorStorage{var->comp_node()}, layout);
opr_base->mutable_values()[index]->reset(
DeviceTensorStorage{var->comp_node()}, layout);
memory_need_reorder = true;
mgb_log_debug(
"preprocessed weight is freed, var name "
"= %s, var layout = %s",
var->name().c_str(),
layout.to_string().c_str());
}
m_already_free_no_need_mem = true;
}
}
}
//! reorder the other still-needed outputs, because they share the
//! same chunk of device memory with the no-need vars, see
//! BatchedDeviceValueLoader
if (memory_need_reorder) {
auto opr_base =
static_cast<opr::intl::MultipleDeviceTensorHolderBase*>(
opr);
auto comp_node = opr_base->output(0)->comp_node();
bool is_device_opr =
comp_node.mem_node() != CompNode::default_cpu().mem_node();
if (memory_need_reorder && is_device_opr) {
for (size_t index = 0; index < opr_base->output().size();
index++) {
auto var = opr_base->output(index);
if (!var->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
DeviceTensorStorage storage(var->comp_node());
size_t size = var->layout().span().dist_byte();
storage.ensure_size(size);
storage.copy_from(var->m_dev_tensor.storage(), size);
var->m_dev_tensor.reset(storage, var->layout());
opr_base->mutable_values()[index]->reset(storage,
var->layout());
reordered = true;
}
}
//! sync to make sure the memory copy is finished
comp_node.sync();
}
}
}
return reordered;
}
void VarNodeMemManager::init_dynamic_alloc_opr_info() {
mgb_assert(m_first_static_plan_run);
m_need_post_exec_action_vars.clear();
......
......@@ -173,6 +173,14 @@ class VarNodeMemManager {
*/
bool alloc_var_node_mem_static();
/*!
 * \brief free the memory of vars with the MEMORY_NO_NEED flag
 *
 * \return whether the memory of a MEMORY_NO_NEED var, or of another var
 * sharing its storage, was changed
*/
bool free_combine_memory_no_need_var();
/*!
* \brief initialize static memory allocation plan
*
......@@ -407,7 +415,8 @@ class VarNodeMemManager {
bool check_need_realloc();
};
bool m_first_static_plan_run = true, m_optimize_started = false;
bool m_first_static_plan_run = true, m_optimize_started = false,
m_already_free_no_need_mem = false;
ComputingGraphImpl *m_owner_graph;
ThinHashMap<VarNode*, VarNodeMemTrait> m_node_mem_trait;
NullableHashMap<OperatorNodeBase*, DynamicAllocOprInfo>
......
......@@ -449,7 +449,11 @@ DEF(resize, &)(const TensorShape& shape) {
}
DEF(reset, &)(TensorStorage storage, const TensorLayout &layout) {
mgb_assert(!layout.ndim || storage.valid_span(layout.span()));
//! The storage to be reset must either satisfy the layout or be empty.
//! Empty storage is used after weight preprocess to save memory while
//! still allowing the layout to be checked at runtime
mgb_assert(!layout.ndim || storage.valid_span(layout.span()) ||
storage.empty());
m_storage = std::move(storage);
m_layout = layout;
return static_cast<ChainReturnType&>(*this);
......
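The relaxed assertion above lets a tensor be reset with empty storage while its layout is kept. A minimal sketch of how this commit relies on that, mirroring free_combine_memory_no_need_var earlier in the diff (`var` is assumed to be a VarNode holding an already-preprocessed weight):

// sketch only: keep the layout so later runs can still validate shapes,
// but drop the device storage to reclaim the weight memory
auto layout = var->layout();
var->m_dev_tensor.reset(DeviceTensorStorage{var->comp_node()}, layout);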
......@@ -98,7 +98,8 @@ struct GraphCommonOptimizeOptions {
//! whether to enable fast-run profiled winograd opr replace
bool weight_winograd_transform = false;
//! whether to enable weight preprocess, if enabled it may use more
//! memory, default disable now
//! memory, so it is disabled by default. When weight preprocess is enabled,
//! the input shape must not change across executions
bool weight_preprocess = false;
enum LayoutTransform : uint32_t {
DEFAULT,
......
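For context, enabling the option from user code follows the pattern used by the tests added in this commit; a minimal sketch (graph construction and network building are assumed):

// weight preprocess may use extra memory and requires fixed input shapes
auto graph = mgb::ComputingGraph::make();
graph->options().graph_opt.weight_preprocess = true;
// build the network, compile and execute as usual; weights consumed only by
// the preprocess step are flagged on the first run and freed on the second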
......@@ -589,7 +589,7 @@ class VarNode final: public GraphNodeBase {
friend class imperative::ProxyGraph;
};
enum class VarNode::Flag: uint32_t {
enum class VarNode::Flag : uint32_t {
//! do not allocate memory by the system allocator even if shape could be
//! inferred
NO_SYS_MEM_ALLOC = 1 << 0,
......@@ -667,6 +667,12 @@ enum class VarNode::Flag: uint32_t {
* after FLAG_FREEZED is present.
*/
FLAG_FREEZED = 1 << 10,
/*!
 * this flag indicates that the data of this var has already been processed
 * and is not needed later, so it can be freed; used by weight preprocess to
 * save memory
*/
MEMORY_NO_NEED = 1 << 11,
};
MGB_DEF_ENUM_CLASS_BIT_OPR(VarNode::Flag)
......
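A hedged sketch of how MEMORY_NO_NEED is raised, mirroring the convolution operators later in this diff: the flag is only added when the current computing sequence has no other consumer of the weight's device value, host value or shape (`weight` is an assumed VarNode* feeding a preprocess-capable operator):

auto&& recv = weight->owner_graph()->var_receiver_in_current_comp_seq(weight);
if (recv.dev_value == 1 && recv.host_value == 0 && recv.shape == 0) {
    // safe to drop the device data once preprocessing has consumed it
    weight->add_flag(VarNode::Flag::MEMORY_NO_NEED);
}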
......@@ -1920,4 +1920,236 @@ TEST(TestGraph, NaiveRecord2NCHW44) {
func->execute().wait();
}
namespace {
template <typename DnnOp, typename... Args>
typename DnnOp::Algorithm* try_find_any_weight_preprocess_algo(
DnnOp* dnn_op, const char* mgb_info, Maybe<bool>& found,
Args&& ...args) {
if (found.valid()) {
if (found.val()) {
return dnn_op->execution_policy().algorithm;
} else {
return nullptr;
}
}
for (auto&& algo : dnn_op->get_all_algorithms(
std::forward<Args>(args)...)) {
dnn_op->execution_policy().algorithm = algo;
auto layouts = dnn_op->deduce_preprocessed_filter_layout(
std::forward<Args>(args)...);
if (layouts.empty()) continue;
bool valid = false;
for (auto&& l: layouts) {
if (!l.is_empty()) {
valid = true;
break;
}
}
if (valid) {
found.emplace(true);
return algo;
}
}
found.emplace(false);
mgb_log_warn("Can't find weight preprocess algo for op %s", mgb_info);
return nullptr;
}
void test_free_memory_in_weight_preprocess(int record_level, CompNode cn) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt.weight_preprocess = true;
graph->options().comp_node_seq_record_level = record_level;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
.rename(name);
};
auto x = mkvar("x", {1, 32, 16, 16});
// ConvBias test dense
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 0;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {32, 32, 1, 1}), b1 = mkcvar("b1", {1, 32, 1, 1});
auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
Maybe<bool> wp1, wp2;
conv1.node()->owner_opr()->cast_final_safe<opr::ConvBias>()
.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
return try_find_any_weight_preprocess_algo(
opr->cast_final_safe<opr::ConvBias>().megdnn_opr(),
opr->cname(), wp1,
opr->input(0)->layout(), opr->input(1)->layout(),
opr->input(2)->layout(), TensorLayout{},
opr->output(0)->layout());
});
// Convolution
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 0;
param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w2 = mkcvar("w2", {32, 32, 1, 1});
auto y = opr::Convolution::make(conv1, w2, param_conv);
y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
return try_find_any_weight_preprocess_algo(
opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
opr->cname(), wp2,
opr->input(0)->layout(), opr->input(1)->layout(),
opr->output(0)->layout());
});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
//! the first execution flags the vars whose memory is no longer needed
func->execute();
//! the second execution frees the memory of those vars
func->execute();
auto check = [&](SymbolVar v) {
ASSERT_TRUE(v.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
ASSERT_TRUE(v.node()->dev_tensor().empty());
ASSERT_TRUE(v.node()->owner_opr()
->cast_final_safe<opr::SharedDeviceTensor>()
.get_dev_tensor()
.empty());
};
ASSERT_TRUE(wp1.valid() && wp2.valid());
if (wp1.val()) {
check(w1);
}
if (wp2.val()) {
check(w2);
}
}
} // anonymous namespace
TEST(TestGraph, FreeMemoryInWeightPreprocess) {
test_free_memory_in_weight_preprocess(0, CompNode::load("xpu0"));
}
TEST(TestGraph, RecordFreeMemoryInWeightPreprocess) {
test_free_memory_in_weight_preprocess(1, CompNode::load("cpu0"));
}
namespace {
MGB_DEFINE_OPR_CLASS(HostValueReader, cg::SingleCNOutshapePureByInshapeOprBase) // {
void scn_do_execute() override {
auto&& hv = owner_graph()->static_infer_manager().infer_value(input(0));
MGB_MARK_USED_VAR(hv);
}
NodeProp* do_make_node_prop() const override {
auto ret = Super::do_make_node_prop();
ret->dep_map()[input(0)] = NodeProp::DepType::HOST_VALUE;
return ret;
}
void get_output_var_shape(
const TensorShapeArray &,
TensorShapeArray &out_shape) const override {
out_shape.at(0) = {};
}
public:
HostValueReader(VarNode* inp)
: Super{inp->owner_graph(), {}, "host_value_reader", {inp}} {
add_input({inp});
using F = VarNode::Flag;
add_output(None)
->add_flag(F::ALLOW_EMPTY_SHAPE)
.add_flag(F::VOLATILE_CONTENT);
}
static SymbolVar make(SymbolVar inp) {
return inp.node()->owner_graph()->insert_opr(
std::make_unique<HostValueReader>(inp.node()))->output(0);
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(HostValueReader);
}
TEST(TestGraph, FreeMemoryInWeightPreprocessWithValueInfer) {
HostTensorGenerator<> gen;
CompNode cn = CompNode::load("xpux");
auto graph = ComputingGraph::make();
graph->options().graph_opt.weight_preprocess = true;
graph->options().var_sanity_check_first_run = false;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
.rename(name);
};
auto x = mkvar("x", {1, 32, 16, 16});
auto w = mkcvar("w", {32, 32, 1, 1});
auto y = opr::Convolution::make(x, w);
Maybe<bool> found;
y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
return try_find_any_weight_preprocess_algo(
opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
opr->cname(), found,
opr->input(0)->layout(), opr->input(1)->layout(),
opr->output(0)->layout());
});
auto reader = HostValueReader::make(w);
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y), {reader, {}}});
func->execute();
// FIXME: fails on the second execution because it requires the host value of
// the empty tensor that was freed in weight preprocess
func->execute();
ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
ASSERT_FALSE(w.node()->dev_tensor().empty());
ASSERT_FALSE(w.node()->owner_opr()
->cast_final_safe<opr::SharedDeviceTensor>()
.get_dev_tensor()
.empty());
}
TEST(TestGraph, FreeMemoryInWeightPreprocessWithMultiReader) {
HostTensorGenerator<> gen;
CompNode cn = CompNode::load("xpux");
auto graph = ComputingGraph::make();
graph->options().graph_opt.weight_preprocess = true;
graph->options().var_sanity_check_first_run = false;
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make_const(*graph, *gen(shp, cn))
.rename(name);
};
auto x = mkvar("x", {1, 32, 16, 16});
auto w = mkcvar("w", {32, 32, 1, 1});
auto y = opr::Convolution::make(x, w);
Maybe<bool> found;
y.node()->owner_opr()->cast_final_safe<opr::Convolution>()
.setup_algo_chooser([&](const cg::OperatorNodeBase* opr) {
return try_find_any_weight_preprocess_algo(
opr->cast_final_safe<opr::Convolution>().megdnn_opr(),
opr->cname(), found,
opr->input(0)->layout(), opr->input(1)->layout(),
opr->output(0)->layout());
});
auto y1 = w * 2 + 1;
HostTensorND host_y, host_y1;
auto func = graph->compile({
make_callback_copy(y, host_y), make_callback_copy(y1, host_y1)});
func->execute();
// FIXME: fails on the second execution because it evaluates the expression
// (w * 2 + 1) with an empty tensor
func->execute();
ASSERT_FALSE(w.node()->contain_flag(VarNode::Flag::MEMORY_NO_NEED));
ASSERT_FALSE(w.node()->dev_tensor().empty());
ASSERT_FALSE(w.node()->owner_opr()
->cast_final_safe<opr::SharedDeviceTensor>()
.get_dev_tensor()
.empty());
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -138,39 +138,36 @@ public:
void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter(
cg::OperatorNodeBase& opr) {
if (!mixin_allow_weight_preprocess(opr))
if (!mixin_allow_weight_preprocess(opr)) {
return;
}
auto new_layout = deduce_preprocessed_filter_layout();
size_t new_size = new_layout.size();
//! An empty preprocessed-filter layout list means weight preprocess is not needed
if (new_layout.empty()) {
// Weight preprocess was needed before, but no longer needed.
if (m_preprocessed_filter) {
m_preprocessed_filter.reset();
m_filter_storage.clear();
return;
}
//! all layouts being empty means weight preprocess is not needed
bool layout_valid = false;
for (auto&& layout : new_layout) {
if (!layout.is_empty()) {
layout_valid = true;
}
}
if (!layout_valid) {
return;
}
bool should_update = false;
size_t new_size = new_layout.size();
if (!m_preprocessed_filter ||
m_preprocessed_filter->tensors.size() != new_size) {
should_update = true;
} else {
if (m_preprocessed_filter) {
for (size_t i = 0; i < new_size; i++) {
if (!new_layout[i].eq_layout(
m_preprocessed_filter->tensors[i].layout)) {
should_update = true;
break;
}
mgb_assert(new_layout[i].eq_layout(
m_preprocessed_filter->tensors[i].layout),
"weight preprocess layout changed, please keep input "
"shape unchanged when weight preprocess is enabled");
}
}
if (!should_update)
return;
if (!m_preprocessed_filter) {
m_preprocessed_filter.reset(new PreprocessedFilter{});
}
m_preprocessed_filter.reset(new PreprocessedFilter{});
m_preprocessed_filter->tensors.resize(new_size);
m_filter_storage.resize(new_size);
m_preprocessed_filter->algorithm_id = nullptr;
......@@ -327,6 +324,14 @@ void ConvolutionForward::scn_do_execute_preprocess() {
input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
output(0)->layout(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
//! Flag input(1) as no longer needed; it can be freed when no other
//! var depends on its dev_value, host_value or shape.
auto receiver_info =
input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
receiver_info.shape == 0) {
input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
}
}
/* ==================== ConvolutionBackwardData ==================== */
......@@ -959,6 +964,14 @@ void ConvBiasForward::scn_do_execute_preprocess() {
input(0)->layout(), input(1)->dev_tensor().as_megdnn(), bias_layout,
z_layout, output(0)->layout(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
//! Flag input(1) as no longer needed; it can be freed when no other
//! var depends on its dev_value, host_value or shape.
auto receiver_info =
input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
receiver_info.shape == 0) {
input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
}
}
/* ===================== LocalShareForward ==================== */
......
......@@ -142,8 +142,10 @@ void intl::DeviceTensorHolder::add_output(DType dtype) {
}
void intl::DeviceTensorHolder::record_execute_deps(ExecDependencyArray& deps) {
deps.emplace_back(
std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
if (!output(0)->contain_flag(VarNode::Flag::MEMORY_NO_NEED)) {
deps.emplace_back(
std::make_unique<DevValueExecDep>(get_dev_tensor().storage()));
}
}
/* ===================== Host2DeviceCopy ===================== */
......@@ -801,14 +803,19 @@ class intl::MultipleDeviceTensorHolderBase::DevValuesExecDep final
SmallVector<DeviceTensorStorage> m_vals;
public:
explicit DevValuesExecDep(const ValueArray& vals) {
for (auto&& val : vals) {
m_vals.emplace_back(std::move(val->storage()));
explicit DevValuesExecDep(const ValueArray& vals,
MultipleDeviceTensorHolderBase* opr) {
mgb_assert(vals.size() == opr->output().size(),
           "the number of output values differs from the number of output vars");
for (size_t index = 0; index < vals.size(); index++) {
if (!opr->output(index)->contain_flag(
VarNode::Flag::MEMORY_NO_NEED)) {
m_vals.emplace_back(std::move(vals[index]->storage()));
}
}
}
};
intl::MultipleDeviceTensorHolderBase::MultipleDeviceTensorHolderBase(
ComputingGraph& graph, ValueArray values,
const OperatorNodeConfig& config)
......@@ -887,8 +894,7 @@ intl::MultipleDeviceTensorHolderBase::do_make_node_prop() const {
void intl::MultipleDeviceTensorHolderBase::record_execute_deps(
ExecDependencyArray& deps) {
deps.emplace_back(
std::make_unique<DevValuesExecDep>(values()));
deps.emplace_back(std::make_unique<DevValuesExecDep>(values(), this));
}
/* ===================== MultipleDeviceTensorHolder ===================== */
......
......@@ -173,9 +173,15 @@ size_t AlgoChooser<Opr>::setup_algo(const ConvTensorLayouts& layouts,
return 0;
}
ImplAlgo algo = nullptr;
ExeContext ctx(layouts, megdnn_opr, mgb_opr, allow_weight_preprocess);
auto algo = get_algo(ctx);
if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
algo = algo_choose_hook(mgb_opr);
}
if (!algo) {
algo = get_algo(ctx);
}
size_t workspace = ctx.get_workspace_size_bytes(algo);
mgb_log_debug(
"%s: tensor layouts(%s %s, %s %s) -> (%s %s): algo=%s "
......@@ -360,16 +366,29 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
if (!m_allow_weight_preprocess)
return;
auto opr = _(m_megdnn_opr);
auto layout = APPLY(opr->deduce_preprocessed_filter_layout(args...),
m_layouts);
if (layout.empty())
auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
m_layouts);
//! An empty preprocessed-filter layout list means weight preprocess is not needed
if (layouts.empty()) {
return;
}
//! all layouts being empty means weight preprocess is not needed
bool layout_valid = false;
for (auto&& layout : layouts) {
if (!layout.is_empty()) {
layout_valid = true;
}
}
if (!layout_valid) {
return;
}
result = PreprocessFilter<Opr>{};
auto& res = result.val();
res.algorithm_id = nullptr;
res.tensors.resize(layout.size());
for (size_t i = 0; i < layout.size(); i++) {
res.tensors[i] = megdnn::TensorND(nullptr, layout[i]);
res.tensors.resize(layouts.size());
for (size_t i = 0; i < layouts.size(); i++) {
res.tensors[i] = megdnn::TensorND(nullptr, layouts[i]);
}
});
return result;
......
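The non-empty-layout check above also appears in mixin_update_preprocessed_filter; a small helper expressing the same logic could look like the following (an assumption for illustration, not part of the diff):

// returns true if at least one preprocessed-filter layout is non-empty,
// i.e. weight preprocess actually produces data for the chosen algorithm
template <typename LayoutArray>
bool has_nonempty_preprocessed_layout(const LayoutArray& layouts) {
    for (auto&& layout : layouts) {
        if (!layout.is_empty())
            return true;
    }
    return false;
}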
......@@ -25,6 +25,9 @@ namespace mixin {
class Convolution {
public:
using ExecutionPolicy = megdnn::param::ExecutionPolicy;
using Algorithm = megdnn::detail::Algorithm;
using AlgoChooserHook =
std::function<Algorithm*(const OperatorNodeBase*)>;
const ExecutionPolicy& execution_policy() const {
if (!m_policy_accessed) {
......@@ -55,6 +58,16 @@ class Convolution {
virtual std::pair<const void*, size_t> param_blob() const = 0;
/*!
* \brief register a hook to implement custom algo chooser
*/
void setup_algo_chooser(AlgoChooserHook&& func) {
m_algo_chooser = func;
}
AlgoChooserHook algo_chooser() const {
return m_algo_chooser;
}
protected:
~Convolution();
......@@ -63,6 +76,8 @@ class Convolution {
std::unique_ptr<AlgoChooserProfileCache> m_profile_cache;
AlgoChooserHook m_algo_chooser;
virtual void init_profile_cache() = 0;
//! init output desc for conv backward data oprs; it handles both grad
......
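Registering the new hook follows the tests added in this commit; a minimal sketch (`conv` is an assumed SymbolVar produced by opr::Convolution::make; returning nullptr falls back to the default chooser in setup_algo):

conv.node()->owner_opr()->cast_final_safe<opr::Convolution>()
        .setup_algo_chooser([](const cg::OperatorNodeBase*)
                                    -> megdnn::detail::Algorithm* {
            // a custom policy could return a weight-preprocess-capable algo;
            // nullptr means "no preference", so get_algo() decides as before
            return nullptr;
        });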
......@@ -99,6 +99,11 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
return *m_dev_data;
}
void free_dev_data() {
m_dev_data->reset(DeviceTensorStorage{m_dev_data->comp_node()},
m_dev_data->layout());
}
const std::shared_ptr<DeviceTensorND>& dev_data() const {
return m_dev_data;
}
......@@ -122,6 +127,10 @@ public:
const OperatorNodeConfig& config);
const ValueArray& values() const { return m_values; }
ValueArray& mutable_values() {
return m_values;
}
protected:
ValueArray m_values;
......@@ -292,7 +301,7 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
static SymbolVar make_const(ComputingGraph& graph,
const HostTensorND& value,
const OperatorNodeConfig& config = {}) {
return make(graph, value, false, config);
return make(graph, value, true, config);
}
};
......