提交 1a21dfdf 编写于 作者: M Megvii Engine Team

refactor(dtr): import dtr as submodule

GitOrigin-RevId: abecd0f176e6bb292cd19f129b11b43be41f891c
上级 6964576e
...@@ -76,7 +76,6 @@ from .core._imperative_rt.core2 import full_sync as _full_sync ...@@ -76,7 +76,6 @@ from .core._imperative_rt.core2 import full_sync as _full_sync
from .core._imperative_rt.core2 import sync as _sync from .core._imperative_rt.core2 import sync as _sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import * from .device import *
from .dtr import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Parameter, Tensor, tensor from .tensor import Parameter, Tensor, tensor
...@@ -100,6 +99,7 @@ del _set_fork_exec_path_for_timed_func ...@@ -100,6 +99,7 @@ del _set_fork_exec_path_for_timed_func
import megengine.autodiff import megengine.autodiff
import megengine.data import megengine.data
import megengine.distributed import megengine.distributed
import megengine.dtr
import megengine.functional import megengine.functional
import megengine.hub import megengine.hub
import megengine.jit import megengine.jit
......
...@@ -11,14 +11,14 @@ from typing import Union ...@@ -11,14 +11,14 @@ from typing import Union
from mprop import mproperty from mprop import mproperty
from .core._imperative_rt.core2 import set_option from .core._imperative_rt.core2 import set_option as _set_option
from .core._imperative_rt.utils import _set_defrag from .core._imperative_rt.utils import _set_defrag
_eviction_threshold = 0 _eviction_threshold = 0
_evictee_minimum_size = 1024 ** 2 _evictee_minimum_size = 1024 ** 2
def str2bytes(text: str) -> int: def _str2bytes(text: str) -> int:
regex = re.compile(r"(\d+(?:\.\d+)?)\s*([kmg]?b)", re.IGNORECASE) regex = re.compile(r"(\d+(?:\.\d+)?)\s*([kmg]?b)", re.IGNORECASE)
order = ["b", "kb", "mb", "gb"] order = ["b", "kb", "mb", "gb"]
result = regex.findall(text) result = regex.findall(text)
...@@ -32,7 +32,9 @@ def str2bytes(text: str) -> int: ...@@ -32,7 +32,9 @@ def str2bytes(text: str) -> int:
@mproperty @mproperty
def eviction_threshold(mod): def eviction_threshold(mod):
r""" r"""
Returns the eviction threshold in bytes. Get or set the eviction threshold in bytes. It can also be set to a string,
whose formatting supports byte(B), kilobyte(KB), megabyte(MB) and
gigabyte(GB) units.
.. note:: .. note::
...@@ -40,40 +42,34 @@ def eviction_threshold(mod): ...@@ -40,40 +42,34 @@ def eviction_threshold(mod):
and evict resident tensors until the amount of used memory falls below and evict resident tensors until the amount of used memory falls below
this threshold. this threshold.
"""
return mod._eviction_threshold
@eviction_threshold.setter
def eviction_threshold(mod, value: Union[int, str]):
r"""
Change the eviction threshold. If `value` is an int, it represents the
number of bytes. If `value` is a string, its formatting supports bytes(B),
kilobyte(KB), megabyte(MB) and gigabyte(GB) units.
Examples: Examples:
.. code-block:: .. code-block::
import megengine as mge import megengine as mge
mge.dtr.eviction_threshold = 2 * 1024 ** 3
mge.dtr.eviction_threshold = "2GB" mge.dtr.eviction_threshold = "2GB"
mge.dtr.eviction_threshold = "2048MB"
""" """
return mod._eviction_threshold
@eviction_threshold.setter
def eviction_threshold(mod, value: Union[int, str]):
if isinstance(value, str): if isinstance(value, str):
mod._eviction_threshold = mod.str2bytes(value) mod._eviction_threshold = mod._str2bytes(value)
elif isinstance(value, int): elif isinstance(value, int):
mod._eviction_threshold = value mod._eviction_threshold = value
else: else:
raise TypeError("`value` should be a str or an int") raise TypeError("`value` should be a str or an int")
set_option("dtr_eviction_threshold", mod._eviction_threshold) _set_option("dtr_eviction_threshold", mod._eviction_threshold)
@mproperty @mproperty
def evictee_minimum_size(mod): def evictee_minimum_size(mod):
r""" r"""
Returns the memory threshold of tensors in bytes. Get or set the memory threshold of tensors in bytes. It can also be set to a
string, whose formatting supports byte(B), kilobyte(KB), megabyte(MB) and
gigabyte(GB) units.
.. note:: .. note::
...@@ -81,34 +77,26 @@ def evictee_minimum_size(mod): ...@@ -81,34 +77,26 @@ def evictee_minimum_size(mod):
candidate set. A tensor that is not added to the candidate set will candidate set. A tensor that is not added to the candidate set will
never be evicted during its lifetime. never be evicted during its lifetime.
"""
return mod._evictee_minimum_size
@evictee_minimum_size.setter
def evictee_minimum_size(mod, value: Union[int, str]):
r"""
Change the memory threshold of tensors. If `value` is an int, it represents
the number of bytes. If `value` is a string, its formatting supports bytes(B),
kilobyte(KB), megabyte(MB) and gigabyte(GB) units.
Examples: Examples:
.. code-block:: .. code-block::
import megengine as mge import megengine as mge
mge.dtr.evictee_minimum_size = 2 * 1024 ** 2
mge.dtr.evictee_minimum_size = "2MB" mge.dtr.evictee_minimum_size = "2MB"
mge.dtr.evictee_minimum_size = "2048KB"
""" """
return mod._evictee_minimum_size
@evictee_minimum_size.setter
def evictee_minimum_size(mod, value: Union[int, str]):
if isinstance(value, str): if isinstance(value, str):
mod._evictee_minimum_size = mod.str2bytes(value) mod._evictee_minimum_size = mod._str2bytes(value)
elif isinstance(value, int): elif isinstance(value, int):
mod._evictee_minimum_size = value mod._evictee_minimum_size = value
else: else:
raise TypeError("`value` should be a str or an int") raise TypeError("`value` should be a str or an int")
set_option("dtr_evictee_minimum_size", mod._evictee_minimum_size) _set_option("dtr_evictee_minimum_size", mod._evictee_minimum_size)
def enable(): def enable():
...@@ -116,16 +104,16 @@ def enable(): ...@@ -116,16 +104,16 @@ def enable():
Enable to record computing path of tensors and to perform DTR policy. Enable to record computing path of tensors and to perform DTR policy.
""" """
_set_defrag(True) _set_defrag(True)
set_option("enable_dtr_auto_drop", 1) _set_option("enable_dtr_auto_drop", 1)
set_option("enable_drop", 1) _set_option("enable_drop", 1)
set_option("buffer_length", 0) _set_option("buffer_length", 0)
set_option("record_computing_path", 1) _set_option("record_computing_path", 1)
def disable(): def disable():
r""" r"""
Stop recording computing path of tensors and performing DTR policy. Stop recording computing path of tensors and performing DTR policy.
""" """
set_option("enable_dtr_auto_drop", 0) _set_option("enable_dtr_auto_drop", 0)
set_option("enable_drop", 0) _set_option("enable_drop", 0)
set_option("record_computing_path", 0) _set_option("record_computing_path", 0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册