提交 56b94d89 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(dtr): add sqrt sampling

GitOrigin-RevId: 8cb2ceb520ad08070444c3a3e99da043b4ffe090
上级 8a73193c
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import mprop
from .dtr import *
mprop.init()
......@@ -9,13 +9,12 @@
import re
from typing import Union
from mprop import mproperty
from .core._imperative_rt.core2 import set_option as _set_option
from .core._imperative_rt.utils import _set_defrag
from ..core._imperative_rt.core2 import set_option as _set_option
from ..core._imperative_rt.utils import _set_defrag
_eviction_threshold = 0
_evictee_minimum_size = 1024 ** 2
_enable_sqrt_sampling = False
def _str2bytes(text: str) -> int:
......@@ -29,7 +28,7 @@ def _str2bytes(text: str) -> int:
return int(float(result[0][0]) * 1024 ** order.index(result[0][1].lower()))
@mproperty
@property
def eviction_threshold(mod):
r"""
Get or set the eviction threshold in bytes. It can also be set to a string,
......@@ -50,21 +49,22 @@ def eviction_threshold(mod):
mge.dtr.eviction_threshold = "2GB"
"""
return mod._eviction_threshold
return _eviction_threshold
@eviction_threshold.setter
def eviction_threshold(mod, value: Union[int, str]):
global _eviction_threshold
if isinstance(value, str):
mod._eviction_threshold = mod._str2bytes(value)
_eviction_threshold = _str2bytes(value)
elif isinstance(value, int):
mod._eviction_threshold = value
_eviction_threshold = value
else:
raise TypeError("`value` should be a str or an int")
_set_option("dtr_eviction_threshold", mod._eviction_threshold)
_set_option("dtr_eviction_threshold", _eviction_threshold)
@mproperty
@property
def evictee_minimum_size(mod):
r"""
Get or set the memory threshold of tensors in bytes. It can also be set to a
......@@ -85,18 +85,45 @@ def evictee_minimum_size(mod):
mge.dtr.evictee_minimum_size = "2MB"
"""
return mod._evictee_minimum_size
return _evictee_minimum_size
@evictee_minimum_size.setter
def evictee_minimum_size(mod, value: Union[int, str]):
global _evictee_minimum_size
if isinstance(value, str):
mod._evictee_minimum_size = mod._str2bytes(value)
_evictee_minimum_size = _str2bytes(value)
elif isinstance(value, int):
mod._evictee_minimum_size = value
_evictee_minimum_size = value
else:
raise TypeError("`value` should be a str or an int")
_set_option("dtr_evictee_minimum_size", mod._evictee_minimum_size)
_set_option("dtr_evictee_minimum_size", _evictee_minimum_size)
@property
def enable_sqrt_sampling(mod):
    r"""
    Get or set whether sqrt sampling is enabled.

    With sqrt sampling on, a search over a candidate set of N tensors
    examines only about sqrt(N) of them rather than all N. When the number
    of tensors is very high, enabling this optimization will speed up the
    training.

    Examples:

        .. code-block::

           import megengine as mge
           mge.dtr.enable_sqrt_sampling = True
    """
    return _enable_sqrt_sampling


@enable_sqrt_sampling.setter
def enable_sqrt_sampling(mod, value: bool):
    # Record the flag at module level so the getter can report it, then
    # forward the same value to the runtime through ``_set_option``.
    global _enable_sqrt_sampling
    _enable_sqrt_sampling = value
    _set_option("enable_dtr_sqrt_sampling", value)
def enable():
......
......@@ -761,7 +761,7 @@ bool ChannelImpl::auto_evict(size_t force_num=0) {
while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) {
RECORD_EVENT(AutoEvictEvent);
sample_on_device(m_dtr.comp_node, false);
auto best = m_dtr.find_best_tensor();
auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling && !force_num);
if (!best) {
break;
}
......@@ -988,8 +988,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
TensorInfo::ComputePath::make(cmd.id, cmd.op, cmd.inputs, cmd.outputs);
size_t detach_cnt = 0;
if (!strcmp(get_name(*cmd.op), "BatchNorm") && cmd.outputs.size() == 5) {
cmd.outputs[0]->detach_producer(); // detach running_mean
cmd.outputs[1]->detach_producer(); // detach running_var
for (auto input : cmd.inputs) {
input->ref_cnt -= 2;
}
}
for (auto output : cmd.outputs) {
if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
if (output->producer && !output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
output->detach_producer();
detach_cnt ++;
}
......@@ -1339,9 +1346,15 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
return cost;
}
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() {
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(bool enable_dtr_sqrt_sampling=false) {
double min_msps = -1;
TensorInfo* best = nullptr;
size_t sz = 1;
if (enable_dtr_sqrt_sampling) {
while (sz * sz <= candidates.size()) sz ++;
} else {
sz = candidates.size();
}
for (auto i : candidates) {
if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
double neighbor_cost = estimate_neighbor_cost(i);
......@@ -1354,6 +1367,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor() {
best = i;
}
}
if (--sz == 0) break;
}
return best;
}
......
......@@ -323,7 +323,7 @@ private:
* \return the pointer of the best tensor; nullptr is returned if no
* available tensor is found
*/
TensorInfo* find_best_tensor();
TensorInfo* find_best_tensor(bool);
/*!
* \brief estimate the cost of recomputing tensor ptr
......
......@@ -41,6 +41,7 @@ public:
DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
"enable host compute, thus computation may be done in host event if it's device is gpu.");
DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, "");
DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, "");
DEF_OPTION(dtr_eviction_threshold, "MEGENGINE_DTR_EVICTION_THRESHOLD", 0,
"auto drop will start whenever gpu memory usage exceeds this value.");
DEF_OPTION(dtr_evictee_minimum_size, "MEGENGINE_DTR_EVICTEE_MINIMUM_SIZE", 1048576,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册