Commit 8b764934 authored by Megvii Engine Team

feat(lite): lite support output var no copy option

GitOrigin-RevId: 5b9488cb93fecb70f0ca0018edde29a6039f5510
Parent 7642f66d
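For context, a minimal sketch of how the new option is meant to be used, distilled from the `OutputNoCopy` test added below (the model path, input name, output layout, and all API calls are taken from that test, not from independent documentation):

```cpp
#include "lite/network.h"

using namespace lite;

int main() {
    Config config;
    // New option: write network outputs directly into user-provided memory,
    // skipping the copy from graph-owned storage to the user tensor.
    config.options.force_output_use_user_specified_memory = true;

    auto network = std::make_shared<Network>(config);
    network->load_model("./shufflenet.mge");  // model path from the test below
    // input setup omitted; see the OutputNoCopy test below

    // Bind a user-owned buffer as the output storage before forward().
    std::shared_ptr<Tensor> output = network->get_output_tensor(0);
    Tensor user_buf(
            LiteDeviceType::LITE_CPU,
            Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});
    output->reset(user_buf.get_memory_ptr(), user_buf.get_layout());

    network->forward();
    network->wait();
    // output->get_memory_ptr() now equals user_buf.get_memory_ptr():
    // the result was produced in place, with no output copy.
    return 0;
}
```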
@@ -93,6 +93,7 @@ struct LITE_API Options {
     bool const_shape = false;
     bool force_dynamic_alloc = false;
     bool force_output_dynamic_alloc = false;
+    bool force_output_use_user_specified_memory = false;
     bool no_profiling_on_shape_change = false;
     uint8_t jit_level = 0;
     uint8_t comp_node_seq_record_level = 0;
...
@@ -83,6 +83,7 @@ typedef struct Options {
     int const_shape;
     int force_dynamic_alloc;
     int force_output_dynamic_alloc;
+    int force_output_use_user_specified_memory;
     int no_profiling_on_shape_change;
     int jit_level;
     int comp_node_seq_record_level;
...
@@ -29,6 +29,7 @@ const LiteOptions default_option = {
         .const_shape = false,
        .force_dynamic_alloc = false,
         .force_output_dynamic_alloc = false,
+        .force_output_use_user_specified_memory = false,
         .no_profiling_on_shape_change = false,
         .jit_level = 0,
         .comp_node_seq_record_level = 0,
@@ -122,7 +123,9 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
     lite_config.options.var_sanity_check_first_run =
             c_config.options.var_sanity_check_first_run;
     lite_config.options.const_shape = c_config.options.const_shape;
-    lite_config.options.force_dynamic_alloc = c_config.options.const_shape;
+    lite_config.options.force_dynamic_alloc = c_config.options.force_dynamic_alloc;
+    lite_config.options.force_output_use_user_specified_memory =
+            c_config.options.force_output_use_user_specified_memory;
     lite_config.options.force_output_dynamic_alloc =
             c_config.options.force_output_dynamic_alloc;
     lite_config.options.no_profiling_on_shape_change =
...
@@ -29,6 +29,7 @@ class LiteOptions(Structure):
         ("const_shape", c_int),
         ("force_dynamic_alloc", c_int),
         ("force_output_dynamic_alloc", c_int),
+        ("force_output_use_user_specified_memory", c_int),
         ("no_profiling_on_shape_change", c_int),
         ("jit_level", c_int),
         ("comp_node_seq_record_level", c_int),
@@ -52,6 +53,7 @@ class LiteOptions(Structure):
         self.const_shape = False
         self.force_dynamic_alloc = False
         self.force_output_dynamic_alloc = False
+        self.force_output_use_user_specified_memory = False
         self.no_profiling_on_shape_change = False
         self.jit_level = 0
         self.comp_node_seq_record_level = 0
@@ -67,6 +69,7 @@ class LiteOptions(Structure):
             "const_shape": bool(self.const_shape),
             "force_dynamic_alloc": bool(self.force_dynamic_alloc),
             "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc),
+            "force_output_use_user_specified_memory": bool(
+                self.force_output_use_user_specified_memory
+            ),
             "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change),
             "jit_level": self.jit_level,
             "comp_node_seq_record_level": self.comp_node_seq_record_level,
...
@@ -84,6 +84,9 @@ void NetworkImplDft::application_config() {
     m_load_config.const_var_shape = m_user_config->options.const_shape;
     ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
     ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
+    ConfigOption(
+            force_output_use_user_specified_memory,
+            force_output_use_user_specified_memory);
     ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
     LITE_ASSERT(
             m_user_config->options.jit_level == 0 ||
@@ -250,7 +253,13 @@ void NetworkImplDft::make_output_spec() {
                 }
             }
         };
+        //! if writing to user-specified memory, the CallbackCaller must be nullptr.
+        if (m_user_config->options.force_output_use_user_specified_memory ||
+            m_user_config->options.force_output_dynamic_alloc) {
+            m_output_spec.emplace_back(load_out, nullptr);
+        } else {
             m_output_spec.emplace_back(load_out, std::move(cb));
+        }
     } else {
         LITE_THROW(ssprintf("no output named : %s in the model", out.name.c_str()));
     }
@@ -444,8 +453,7 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
         }
     }
 
-void NetworkImplDft::try_infer_tensor_layout(
-        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
+void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
     auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
     auto infer_trait = var.node()->get_static_infer_trait();
     if (std::get<0>(infer_trait)) {
@@ -455,9 +463,13 @@ void NetworkImplDft::try_infer_tensor_layout(
                     "Lite infer output shape failed, maybe the model is "
                     "dynamic "
                     "shape.\n");
+            LITE_ASSERT(
+                    !m_user_config->options.force_output_use_user_specified_memory,
+                    "force_output_use_user_specified_memory can't be used when output "
+                    "shape can't be derived.");
             return;
         }
-        Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
+        Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
         tensor->set_layout(layout);
     }
 }
@@ -559,8 +571,7 @@ void NetworkImplDft::update_output() {
          out_it != m_network_io->outputs.end();) {
         if (std::find_if(
                     m_load_result.output_var_list.begin(),
-                    m_load_result.output_var_list.end(),
-                    [out_it](const mgb::SymbolVar var) {
+                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                         return var.node()->name() == out_it->name;
                     }) == m_load_result.output_var_list.end()) {
             LITE_LOG("%s is not the network output, ignore it.", out_it->name.c_str());
@@ -584,7 +595,7 @@ void NetworkImplDft::update_output() {
                 out_it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
-            mgb::SymbolVar var;
+            SymbolVar var;
             for (auto&& out_var : m_load_result.output_var_list) {
                 if (out_var.node()->name() == out_it->name) {
                     var = out_var;
@@ -592,10 +603,12 @@ void NetworkImplDft::update_output() {
                 }
             }
             try_infer_tensor_layout(out_it->lite_tensor, var);
+            output_tensor_copy_optimize(var, out_it->lite_tensor);
         }
         //! user not set, use default output
     } else {
         for (auto&& out : m_load_result.output_var_list) {
+            std::shared_ptr<Tensor> lite_tensor = nullptr;
             auto it = std::find_if(
                     m_network_io->outputs.begin(), m_network_io->outputs.end(),
                     [&out](const IOInner io) { return io.name == out.node()->name(); });
@@ -608,6 +621,7 @@ void NetworkImplDft::update_output() {
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
                 try_infer_tensor_layout(it->lite_tensor, out);
+                lite_tensor = it->lite_tensor;
             } else {
                 IOInner output;
                 output.name = out.node()->name();
@@ -615,8 +629,44 @@ void NetworkImplDft::update_output() {
                         device_id, stream_id, device_type, true);
                 m_network_io->outputs.push_back({output});
                 try_infer_tensor_layout(output.lite_tensor, out);
+                lite_tensor = output.lite_tensor;
             }
+            output_tensor_copy_optimize(out, lite_tensor);
         }
     }
 }
+
+void NetworkImplDft::output_tensor_copy_optimize(
+        Var var, std::shared_ptr<Tensor> tensor) {
+    LITE_ASSERT(
+            !(m_user_config->options.force_output_use_user_specified_memory &&
+              m_user_config->options.force_output_dynamic_alloc),
+            "Can't set force_output_use_user_specified_memory and "
+            "force_output_dynamic_alloc at the same time.");
+    if (m_user_config->options.force_output_use_user_specified_memory) {
+        TensorHelper::implement(tensor)
+                ->cast_final_safe<TensorImplDft>()
+                .set_reset_callback([var](TensorImplDft* dft_tensor) {
+                    dft_tensor->device_share_host_memory();
+                    auto dv = dft_tensor->dev_tensor().get();
+                    dv->comp_node(var.node()->comp_node(), true);
+                    var.node()->init_mem_plan(dv);
+                    var.node()->reset_dev_tensor_from_tensor(*dv);
+                });
+    }
+    if (m_user_config->options.force_output_dynamic_alloc) {
+        TensorHelper::implement(tensor)
+                ->cast_final_safe<TensorImplDft>()
+                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
+                    if (dft_tensor->is_host()) {
+                        auto host_tensor = dft_tensor->m_host_tensor;
+                        *host_tensor =
+                                HostTensorND::make_proxy(var.node()->dev_tensor());
+                    } else {
+                        auto dev_tensor = dft_tensor->m_dev_tensor;
+                        *dev_tensor = var.node()->dev_tensor();
+                    }
+                });
+    }
+}
...
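The core of the no-copy path is the pair of callbacks installed by `output_tensor_copy_optimize` above: a reset callback that rebinds the graph var's device tensor to the user buffer (`init_mem_plan` + `reset_dev_tensor_from_tensor`), and a get-memory callback that lazily proxies the graph's storage when outputs are dynamically allocated. A standalone sketch of that callback wiring (plain C++ with a hypothetical `UserTensor` type, not the MegEngine API):

```cpp
// Standalone sketch: the tensor remembers a std::function installed by the
// "network", and every user reset() re-binds graph-side storage in place.
#include <functional>
#include <iostream>

struct UserTensor {
    void* ptr = nullptr;
    std::function<void(UserTensor*)> on_reset;  // installed by the "network"
    void reset(void* user_buf) {
        ptr = user_buf;
        if (on_reset)
            on_reset(this);  // propagate the new buffer into the graph
    }
};

int main() {
    UserTensor out;
    // network side: rebind its output storage whenever the user resets
    void* graph_storage = nullptr;
    out.on_reset = [&](UserTensor* t) { graph_storage = t->ptr; };

    float buf[4];
    out.reset(buf);
    std::cout << (graph_storage == buf) << '\n';  // prints 1: written in place
}
```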
@@ -12,6 +12,7 @@
 #pragma once
 #include "lite_build_config.h"
+#include "megbrain/graph.h"
 
 #if LITE_BUILD_WITH_MGE
 #include "lite/network.h"
@@ -41,6 +42,7 @@ class NetworkImplDft final : public Network::NetworkImplBase {
 public:
     NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }
     using S = megdnn::param::ExecutionPolicy::Strategy;
+    using Var = mgb::cg::SymbolVar;
 
     //! set the config of the network, include:
     //! the inference device
     //! the other inference options, such as record_level, weight_preprocess...
@@ -207,8 +209,10 @@ private:
     void compile_graph();
 
     //! try to infer output tensor layout
-    void try_infer_tensor_layout(
-            std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var);
+    void try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var);
+
+    //! optimize output tensor copy
+    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);
 
 private:
     bool m_async = false;
...
@@ -149,6 +149,9 @@ Layout TensorImplDft::get_layout() const {
 }
 
 void* TensorImplDft::get_memory_ptr() const {
+    if (m_get_memory_callback) {
+        m_get_memory_callback(const_cast<TensorImplDft*>(this));
+    }
     if (is_host()) {
         return static_cast<void*>(m_host_tensor->raw_ptr());
     } else {
@@ -157,6 +160,9 @@ void* TensorImplDft::get_memory_ptr() const {
 }
 
 void* TensorImplDft::get_memory_ptr(const std::vector<size_t>& idx) const {
+    if (m_get_memory_callback) {
+        m_get_memory_callback(const_cast<TensorImplDft*>(this));
+    }
     if (is_host()) {
         auto elemsize_log = m_host_tensor->layout().dtype.size_log();
         switch (elemsize_log) {
@@ -317,6 +323,9 @@ void TensorImplDft::reset(void* prepared_data) {
         storage.reset(cn, size, raw_storage);
         m_dev_tensor->reset(storage, mge_layout);
     }
+    if (m_reset_callback) {
+        m_reset_callback(this);
+    }
 }
 
 void TensorImplDft::reset(void* prepared_data, const Layout& layout) {
@@ -430,6 +439,34 @@ void TensorImplDft::copy_from_mge_tensor(const mgb::DeviceTensorND& dv) {
     }
 }
 
+void TensorImplDft::set_reset_callback(const std::function<void(TensorImplDft*)>& cb) {
+    m_reset_callback = cb;
+}
+
+void TensorImplDft::set_get_memory_callback(
+        const std::function<void(TensorImplDft*)>& cb) {
+    m_get_memory_callback = cb;
+}
+
+void TensorImplDft::device_share_host_memory() {
+    if (is_host()) {
+        if (!m_dev_tensor) {
+            m_dev_tensor = std::make_shared<mgb::DeviceTensorND>(
+                    m_host_tensor->comp_node(), m_host_tensor->layout());
+        }
+        if (m_host_tensor->raw_ptr() != m_dev_tensor->raw_ptr()) {
+            auto raw_storage = std::shared_ptr<mgb::dt_byte>(
+                    m_host_tensor->raw_ptr(), [](void*) {});
+            auto cn = m_host_tensor->comp_node();
+            auto mge_layout = m_host_tensor->layout();
+            size_t size = mge_layout.span().dist_byte();
+            mgb::DeviceTensorStorage storage;
+            storage.reset(cn, size, raw_storage);
+            m_dev_tensor->reset(storage, mge_layout);
+        }
+    }
+}
+
 #endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
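`device_share_host_memory` makes the device-side tensor alias the host buffer instead of copying it. The key trick is visible above: the externally owned pointer is wrapped in a `shared_ptr` with a no-op deleter before being handed to `DeviceTensorStorage::reset`. A self-contained illustration of that borrowing idiom (plain C++, not MegEngine code):

```cpp
// Minimal sketch of the no-op-deleter borrowing trick: wrap a pointer owned
// elsewhere in a shared_ptr that never frees it, so another storage object
// can alias the same memory without taking ownership.
#include <cassert>
#include <memory>

int main() {
    float host_buf[16];  // owned by the caller, e.g. a host tensor's buffer
    auto borrowed = std::shared_ptr<float>(host_buf, [](float*) {
        // no-op deleter: the storage is borrowed, never freed here
    });
    assert(borrowed.get() == host_buf);  // same memory, no copy, no ownership
}
```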
@@ -97,11 +97,22 @@ public:
     //! get host tensor
     std::shared_ptr<mgb::HostTensorND> host_tensor() const { return m_host_tensor; }
 
     //! get device tensor
     std::shared_ptr<mgb::DeviceTensorND> dev_tensor() const { return m_dev_tensor; }
 
     //! copy from mgb tensor
     void copy_from_mge_tensor(const mgb::DeviceTensorND& dv);
 
+    //! set tensor reset callback
+    void set_reset_callback(const std::function<void(TensorImplDft*)>& cb);
+
+    //! set tensor get memory callback
+    void set_get_memory_callback(const std::function<void(TensorImplDft*)>& cb);
+
+    //! share the same memory between the host and device tensors
+    void device_share_host_memory();
+
 public:
     friend class NetworkImplDft;
 
@@ -115,6 +126,8 @@ private:
     void set_mge_tensor_compnode(const mgb::CompNode& comp_node);
 
 private:
+    std::function<void(TensorImplDft*)> m_get_memory_callback;
+    std::function<void(TensorImplDft*)> m_reset_callback;
     std::shared_ptr<mgb::HostTensorND> m_host_tensor;
     std::shared_ptr<mgb::DeviceTensorND> m_dev_tensor;
 };
...
@@ -153,6 +153,10 @@ std::shared_ptr<Tensor> Network::get_output_tensor(size_t index) {
 
 Network& Network::set_async_callback(const AsyncCallback& callback) {
     LITE_ERROR_HANDLER_BEGIN
+    LITE_ASSERT(
+            !m_config.options.force_output_use_user_specified_memory,
+            "Async mode can't be used with force_output_use_user_specified_memory, "
+            "which writes output data directly to user-specified memory.");
     LITE_CHECK_NON_NULL_POINTER(m_impl);
     m_impl->set_async_callback(std::move(callback));
     return *this;
...
@@ -397,6 +397,73 @@ TEST(TestNetWork, ResetOutput) {
     compare_lite_tensor<float>(output_tensor, result_mgb);
 }
 
+TEST(TestNetWork, OutputNoCopy) {
+    Config config;
+    config.options.force_output_use_user_specified_memory = true;
+    auto tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./shufflenet.mge";
+    std::string input_name = "data";
+    auto result_mgb = mgb_lar(model_path, config, input_name, tensor);
+
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    network->load_model(model_path);
+
+    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
+    auto src_ptr = tensor->get_memory_ptr();
+    auto src_layout = tensor->get_layout();
+    input_tensor->reset(src_ptr, src_layout);
+
+    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+
+    size_t times = 5;
+    std::vector<std::shared_ptr<Tensor>> result_tensors;
+    for (size_t i = 0; i < times; i++) {
+        auto tmp = std::make_shared<Tensor>(
+                LiteDeviceType::LITE_CPU,
+                Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});
+        result_tensors.push_back(tmp);
+    }
+
+    for (size_t i = 0; i < times; i++) {
+        void* out_data = result_tensors[i]->get_memory_ptr();
+        output_tensor->reset(out_data, result_tensors[i]->get_layout());
+
+        network->forward();
+        network->wait();
+
+        ASSERT_EQ(output_tensor->get_memory_ptr(), out_data);
+        compare_lite_tensor<float>(output_tensor, result_mgb);
+    }
+    for (size_t i = 0; i < times; i++) {
+        compare_lite_tensor<float>(result_tensors[i], result_mgb);
+    }
+}
+
+TEST(TestNetWork, OutputDynamicAlloc) {
+    Config config;
+    config.options.force_output_dynamic_alloc = true;
+    auto tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./shufflenet.mge";
+    std::string input_name = "data";
+    auto result_mgb = mgb_lar(model_path, config, input_name, tensor);
+
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    network->load_model(model_path);
+
+    std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
+    auto src_ptr = tensor->get_memory_ptr();
+    auto src_layout = tensor->get_layout();
+    input_tensor->reset(src_ptr, src_layout);
+
+    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+
+    size_t times = 5;
+    for (size_t i = 0; i < times; i++) {
+        network->forward();
+        network->wait();
+
+        compare_lite_tensor<float>(output_tensor, result_mgb);
+    }
+}
+
 TEST(TestNetWork, AsyncExec) {
     Config config;
     config.options.var_sanity_check_first_run = false;
...
@@ -507,13 +507,12 @@ void ComputingGraphImpl::dest_var_optimize(VarNodeArray& dest_vars) {
             i->add_flag(F::NO_MEM_RECLAIM);
         }
     }
-    if (dest_vars[0]->owner_graph()->options().force_output_write_to_user_memory) {
+    if (dest_vars[0]->owner_graph()->options().force_output_use_user_specified_memory) {
         for (auto&& i : dest_vars) {
             mgb_assert(
                     !i->contain_flag(F::RT_FORCE_DYNAMIC_MEM_ALLOC),
-                    "var %s with force dynamic allocate should be set to write output "
-                    "to "
-                    "user memory",
+                    "var %s with RT_FORCE_DYNAMIC_MEM_ALLOC flag can not be forced "
+                    "to write output to user memory",
                     i->cname());
             i->add_flag(
                     F::NO_SYS_MEM_ALLOC | F::NO_SYS_STATIC_MEM_ALLOC |
...
@@ -574,6 +574,10 @@ MemAllocPlan& VarNode::init_mem_plan(const DeviceTensorND* fixed_alloc) {
     return m_mem_plan;
 }
 
+bool VarNode::is_graph_dest_varnode() {
+    return ComputingGraphImpl::downcast(owner_graph())->var_receiver(this).size() == 0;
+}
+
 VarNode& VarNode::add_flag(Flag flag) {
     modify_flag(flag, m_flag | flag);
     return *this;
@@ -582,10 +586,13 @@ VarNode& VarNode::add_flag(Flag flag) {
 void VarNode::modify_flag(Flag delta, Flag new_flag) {
     if (contain_flag(Flag::FLAG_FREEZED)) {
         mgb_assert(
-                (delta & (Flag::NO_SYS_MEM_ALLOC | Flag::NO_MEM_RECLAIM |
-                          Flag::NO_SYS_STATIC_MEM_ALLOC |
-                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC)) == delta ||
-                (new_flag & Flag::MEMORY_NO_NEED),
+                (delta & (Flag::NO_MEM_RECLAIM | Flag::NO_SYS_STATIC_MEM_ALLOC |
+                          Flag::RT_FORCE_DYNAMIC_MEM_ALLOC | Flag::MEMORY_NO_NEED)) ==
+                                delta ||
+                        is_graph_dest_varnode(),
+                "after FLAG_FREEZED is set, a var can only modify the "
+                "NO_MEM_RECLAIM, NO_SYS_STATIC_MEM_ALLOC, "
+                "RT_FORCE_DYNAMIC_MEM_ALLOC and MEMORY_NO_NEED flags, "
+                "unless it is a graph dest var.");
         mgb_assert(
                 !ComputingGraphImpl::downcast(owner_graph())
...
@@ -421,7 +421,7 @@ public:
              * Force the output to be written to user-specified memory, which
              * saves one copy of the output data
              */
-            bool force_output_write_to_user_memory = false;
+            bool force_output_use_user_specified_memory = false;
 
             //! whether to perform var sanity check on first run
             bool var_sanity_check_first_run = true;
...
@@ -549,6 +549,10 @@ private:
     MGE_WIN_DECLSPEC_FUC void modify_flag(Flag delta, Flag new_flag);
 
+    //! whether the var is a graph output; if so, the NO_SYS_MEM_ALLOC flag
+    //! can be modified.
+    bool is_graph_dest_varnode();
+
     MGE_WIN_DECLSPEC_FUC void assign_dev_tensor_from_tensor(
             const DeviceTensorND& value);
...
@@ -82,7 +82,7 @@ TEST(TestNoCopy, BasicInputNoCopy) {
 TEST(TestNoCopy, IONoCopyPtrEQ) {
     auto test_graph = TestGraph();
     auto compute_graph = test_graph.m_network->graph;
-    compute_graph->options().force_output_write_to_user_memory = true;
+    compute_graph->options().force_output_use_user_specified_memory = true;
     test_graph.create_graph();
     auto func = test_graph.compile_without_copy();
     auto&& outvar = func->get_output_vars()[0];
@@ -123,7 +123,7 @@ TEST(TestNoCopy, IONoCopyPtrEQ) {
 TEST(TestNoCopy, IONoCopyCorrect) {
     auto test_graph = TestGraph();
     auto compute_graph = test_graph.m_network->graph;
-    compute_graph->options().force_output_write_to_user_memory = true;
+    compute_graph->options().force_output_use_user_specified_memory = true;
     test_graph.create_graph();
     HostTensorND truth;
     auto func = test_graph.compile_without_copy();
...