提交 f7a5fe17 编写于 作者: M Megvii Engine Team

feat(mge/config): add a config option for memory forwarding

GitOrigin-RevId: 657326154c80751302be0f725ba489b7d6acbec1
上级 9ffc2c0a
...@@ -20,6 +20,7 @@ __all__ = [ ...@@ -20,6 +20,7 @@ __all__ = [
"benchmark_kernel", "benchmark_kernel",
"deterministic_kernel", "deterministic_kernel",
"async_level", "async_level",
"disable_memory_forwarding",
"_compute_mode", "_compute_mode",
"_conv_format", "_conv_format",
"_override", "_override",
...@@ -86,6 +87,25 @@ def async_level(mod, level: int): ...@@ -86,6 +87,25 @@ def async_level(mod, level: int):
set_option("async_level", level) set_option("async_level", level)
@property
def disable_memory_forwarding(mod) -> bool:
    r"""Get or set whether memory forwarding is disabled.

    The default option is ``False``, which means storage may be shared among
    tensors (a tensor may forward/alias another tensor's underlying memory
    instead of owning a private copy). Setting this to ``True`` forces each
    tensor to have its own storage.

    The value is backed by the interpreter option ``"disable_memory_forwarding"``
    via ``get_option``/``set_option``.

    Examples:

        .. code-block::

           import megengine as mge
           mge.config.disable_memory_forwarding = False

    .. note::
       ``mod`` is the enclosing module object: these properties appear to be
       installed on the module itself (module-property pattern), so users access
       them as ``mge.config.disable_memory_forwarding`` — TODO confirm against
       the module's property-installation machinery, which is outside this view.
    """
    # get_option returns the raw option value; coerce to bool for the
    # documented ``-> bool`` contract.
    return bool(get_option("disable_memory_forwarding"))

@disable_memory_forwarding.setter
def disable_memory_forwarding(mod, disable: bool):
    # Forward the flag to the interpreter option of the same name.
    set_option("disable_memory_forwarding", disable)
@property @property
def _compute_mode(mod): def _compute_mode(mod):
r"""Get or set the precision of intermediate results. The default option is "default", r"""Get or set the precision of intermediate results. The default option is "default",
......
...@@ -120,7 +120,6 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() { ...@@ -120,7 +120,6 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
return blob->storage(); return blob->storage();
}; };
OpDef::set_allocator(custom_allocator); OpDef::set_allocator(custom_allocator);
BlobManager::inst()->set_allocator(custom_allocator);
} }
// Do not use m_xxx_state directly // Do not use m_xxx_state directly
...@@ -358,8 +357,8 @@ void ChannelImpl::dispatch_kernel( ...@@ -358,8 +357,8 @@ void ChannelImpl::dispatch_kernel(
init(info, std::move(desc)); init(info, std::move(desc));
// make sure desc's value is consistent with h_value // make sure desc's value is consistent with h_value
if (!info->desc.value.empty()) { if (!info->desc.value.empty()) {
info->h_value = HostTensorND::make_proxy(desc.value) info->h_value = HostTensorND::make_proxy(info->desc.value)
.proxy_to_comp_node(desc.comp_node); .proxy_to_comp_node(info->desc.comp_node);
} }
output_infos.push_back(info); output_infos.push_back(info);
outputs->push_back(reinterpret_cast<Handle>(info)); outputs->push_back(reinterpret_cast<Handle>(info));
...@@ -561,6 +560,15 @@ void ChannelImpl::set_option(std::string name, size_t value) { ...@@ -561,6 +560,15 @@ void ChannelImpl::set_option(std::string name, size_t value) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state(); auto& state = get_channel_state();
state.options.set_option(name, value); state.options.set_option(name, value);
// FIXME
if (name == "enable_dtr_auto_drop" && value) {
auto custom_allocator = [&](CompNode device, size_t size) {
auto blob = Blob::make(device, size);
alloc_tensor_with_evict(blob.get());
return blob->storage();
};
BlobManager::inst()->set_allocator(custom_allocator);
}
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
m_worker.add_task( m_worker.add_task(
{Profiler::next_id(), SetOption{name, value}, {Profiler::next_id(), SetOption{name, value},
...@@ -598,7 +606,7 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) { ...@@ -598,7 +606,7 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
m_valid_handle.insert(reinterpret_cast<Handle>(info)); m_valid_handle.insert(reinterpret_cast<Handle>(info));
MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name); MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated; info->status = TensorInfo::Allocated;
info->desc = desc; info->desc = std::move(desc);
} }
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) { void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
...@@ -694,7 +702,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { ...@@ -694,7 +702,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
} }
// in order to avoid performance impact, // in order to avoid performance impact,
// memory forwarding is disabled when DTR is enabled // memory forwarding is disabled when DTR is enabled
if (state.options.enable_dtr_auto_drop) { if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
ptr->to_contiguous_inplace(); ptr->to_contiguous_inplace();
} }
dest->desc.layout = ptr->layout(); dest->desc.layout = ptr->layout();
......
...@@ -44,6 +44,9 @@ public: ...@@ -44,6 +44,9 @@ public:
enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
"enable host compute, thus computation may be done in host event if it's " "enable host compute, thus computation may be done in host event if it's "
"device is gpu."); "device is gpu.");
DEF_OPTION(
disable_memory_forwarding, "MEGENGINE_DISABLE_MEMORY_FORWARDING", 0,
"disable memory forwarding, thus each tensor has its own storage.");
DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, ""); DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, "");
DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, ""); DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, "");
DEF_OPTION( DEF_OPTION(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册