提交 f7a5fe17 编写于 作者: M Megvii Engine Team

feat(mge/config): add a config option for memory forwarding

GitOrigin-RevId: 657326154c80751302be0f725ba489b7d6acbec1
上级 9ffc2c0a
......@@ -20,6 +20,7 @@ __all__ = [
"benchmark_kernel",
"deterministic_kernel",
"async_level",
"disable_memory_forwarding",
"_compute_mode",
"_conv_format",
"_override",
......@@ -86,6 +87,25 @@ def async_level(mod, level: int):
set_option("async_level", level)
@property
def disable_memory_forwarding(mod) -> bool:
    r"""Get or set config whether to disable memory forwarding. The default option is false,
    which means storage may be shared among tensors.

    When set to True, each tensor keeps its own storage instead of
    forwarding (aliasing) another tensor's memory.

    Examples:

        .. code-block::

            import megengine as mge
            mge.config.disable_memory_forwarding = False
    """
    # The option is stored as an integer in the global option table;
    # coerce to bool for the Python-facing API.
    return bool(get_option("disable_memory_forwarding"))


@disable_memory_forwarding.setter
def disable_memory_forwarding(mod, disable: bool):
    # Persist the flag in the global option table; the interpreter
    # consults it when deciding whether tensor outputs may alias
    # input storage (see the corresponding check in produce_tensor).
    set_option("disable_memory_forwarding", disable)
@property
def _compute_mode(mod):
r"""Get or set the precision of intermediate results. The default option is "default",
......
......@@ -120,7 +120,6 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
return blob->storage();
};
OpDef::set_allocator(custom_allocator);
BlobManager::inst()->set_allocator(custom_allocator);
}
// Do not use m_xxx_state directly
......@@ -358,8 +357,8 @@ void ChannelImpl::dispatch_kernel(
init(info, std::move(desc));
// make sure desc's value is consistent with h_value
if (!info->desc.value.empty()) {
info->h_value = HostTensorND::make_proxy(desc.value)
.proxy_to_comp_node(desc.comp_node);
info->h_value = HostTensorND::make_proxy(info->desc.value)
.proxy_to_comp_node(info->desc.comp_node);
}
output_infos.push_back(info);
outputs->push_back(reinterpret_cast<Handle>(info));
......@@ -561,6 +560,15 @@ void ChannelImpl::set_option(std::string name, size_t value) {
mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
state.options.set_option(name, value);
// FIXME
if (name == "enable_dtr_auto_drop" && value) {
auto custom_allocator = [&](CompNode device, size_t size) {
auto blob = Blob::make(device, size);
alloc_tensor_with_evict(blob.get());
return blob->storage();
};
BlobManager::inst()->set_allocator(custom_allocator);
}
if (Profiler::is_profiling()) {
m_worker.add_task(
{Profiler::next_id(), SetOption{name, value},
......@@ -598,7 +606,7 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
m_valid_handle.insert(reinterpret_cast<Handle>(info));
MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated;
info->desc = desc;
info->desc = std::move(desc);
}
void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
......@@ -694,7 +702,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
}
// in order to avoid performance impact,
// memory forwarding is disabled when DTR is enabled
if (state.options.enable_dtr_auto_drop) {
if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
ptr->to_contiguous_inplace();
}
dest->desc.layout = ptr->layout();
......
......@@ -44,6 +44,9 @@ public:
enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
"enable host compute, thus computation may be done in host event if it's "
"device is gpu.");
DEF_OPTION(
disable_memory_forwarding, "MEGENGINE_DISABLE_MEMORY_FORWARDING", 0,
"disable memory forwarding, thus each tensor has its own storage.");
DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, "");
DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, "");
DEF_OPTION(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册