Commit e393d1cf authored by: Megvii Engine Team

feat(mge/amp): add convert_format module for NHWC training

GitOrigin-RevId: 1b41e1042c0107d2b63f7753d20121fa04b17bd2
Parent 533fb5bf
......@@ -2,6 +2,7 @@ import mprop
from ..core.tensor.amp import *
from .autocast import autocast
from .convert_format import convert_module_format, convert_tensor_format
from .grad_scaler import GradScaler
mprop.init()
import functools
from ..core import _config
from ..core.tensor import amp
......@@ -50,24 +51,37 @@ class autocast:
self._origin_high = None
self._origin_low = None
self._origin_configs = None
def __enter__(self):
self._origin_enabled = amp._enabled
self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._enabled = self.enabled
amp._set_amp_dtype_autocast(self.enabled)
if not self.enabled:
return
self._origin_high = amp._get_amp_high_prec_dtype()
self._origin_low = amp._get_amp_low_prec_dtype()
amp._set_amp_high_prec_dtype(self.high_prec_dtype)
amp._set_amp_low_prec_dtype(self.low_prec_dtype)
self._origin_configs = _config._reset_execution_config(compute_mode="float32")
def __exit__(self, *args):
amp._enabled = self._origin_enabled
amp._set_amp_dtype_autocast(self._origin_enabled)
if not self.enabled:
return
amp._set_amp_high_prec_dtype(self._origin_high)
amp._set_amp_low_prec_dtype(self._origin_low)
_config._reset_execution_config(*self._origin_configs)
def __call__(self, func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not self.enabled:
return func(*args, **kwargs)
with self:
return func(*args, **kwargs)
......
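For context, a minimal sketch of how `autocast` is typically used after this change, both as a context manager and as a decorator; the tensor shapes below are illustrative assumptions, not part of the diff.

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine import amp

x = mge.tensor(np.random.randn(2, 3, 16, 16), dtype="float32")
w = mge.tensor(np.random.randn(8, 3, 3, 3), dtype="float32")

# context-manager form: the low-precision dtype is used inside the block
with amp.autocast(enabled=True):
    y = F.conv2d(x, w)

# decorator form: the wrapper only enters the context when enabled
@amp.autocast(enabled=True)
def forward(inp, weight):
    return F.conv2d(inp, weight)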
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from .. import functional as F
from ..module import Module
from ..tensor import Tensor
def _is_nchw_format(param: Tensor):
# TODO: use better condition
return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
def convert_tensor_format(x: Tensor, inplace: bool = True):
"""Convert NCHW Tensor to NHWC Tensor."""
if x.ndim == 4:
pattern = (0, 2, 3, 1)
elif x.ndim == 5:
pattern = (0, 1, 3, 4, 2)
else:
raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
# TODO: use initialization from tensor after fixing format setting
if inplace:
x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
else:
x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
return x
def convert_module_format(module: Module, inplace: bool = True):
"""Convert NCHW Module to NHWC Module."""
if not inplace:
module = deepcopy(module)
for name, param in module.named_tensors():
if _is_nchw_format(param):
# host value should still be valid, so there is no d2h cost.
convert_tensor_format(param, inplace=True)
return module
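A minimal usage sketch for the two new helpers; the module and shapes are illustrative assumptions, and the `format` attribute is assumed to behave as exercised in the tests further down.

import numpy as np
import megengine as mge
import megengine.module as M
from megengine import amp

# convert a single 4D NCHW tensor into NHWC
t = mge.Tensor(np.ones((1, 3, 8, 8), dtype=np.float32))
t_nhwc = amp.convert_tensor_format(t, inplace=False)
assert t_nhwc.format == "nhwc"

# convert every 4D/5D parameter and buffer of a module in place
net = amp.convert_module_format(M.Conv2d(3, 8, 3))
assert net.weight.format == "nhwc"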
import weakref
from typing import Callable, Iterable, List, Union
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core._imperative_rt.core2 import (
get_auto_format_convert,
pop_scope,
push_scope,
set_auto_format_convert,
set_option,
)
from ..core.autodiff.grad import Grad
from ..core.tensor.dtype import is_differentible_dtype
from ..logger import get_logger
......@@ -253,6 +259,8 @@ class GradManager:
"""
push_scope("backward")
set_option("record_computing_path", 0)
_origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
from ..functional import ones_like
global backwarding_grad_manager
......@@ -296,6 +304,7 @@ class GradManager:
self.release()
backwarding_grad_manager = cache
set_option("record_computing_path", 1)
set_auto_format_convert(_origin_auto_format)
pop_scope("backward")
def record(self):
......
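The save/disable/restore pattern added here is the same one used in `Optimizer.step` further down. Expressed on its own (the helper name `_without_auto_format_convert` is hypothetical; only the two accessors come from the diff), it looks like:

from megengine.core._imperative_rt.core2 import (
    get_auto_format_convert,
    set_auto_format_convert,
)

def _without_auto_format_convert(fn, *args, **kwargs):
    # temporarily turn off automatic NCHW<->NHWC conversion, then restore the old value
    origin = get_auto_format_convert()
    set_auto_format_convert(False)
    try:
        return fn(*args, **kwargs)
    finally:
        set_auto_format_convert(origin)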
......@@ -10,8 +10,10 @@ from ._imperative_rt.core2 import (
set_option,
)
# use "default" to distinguish it from None in _reset_execution_config
__compute_mode = "default"
__conv_format = "default"
__bn_format = "default"
_benchmark_kernel = False
_deterministic_kernel = False
......@@ -22,6 +24,8 @@ __all__ = [
"disable_memory_forwarding",
"_compute_mode",
"_conv_format",
"_bn_format",
"_auto_format_convert",
"_override",
]
......@@ -32,6 +36,7 @@ def benchmark_kernel(mod):
which means use heuristic to choose the fastest algorithm.
Examples:
.. code-block::
import megengine as mge
......@@ -55,6 +60,7 @@ def deterministic_kernel(mod):
which means the algorithm is not reproducible.
Examples:
.. code-block::
import megengine as mge
......@@ -75,6 +81,7 @@ def async_level(mod) -> int:
which means both device and user side errors are async.
Examples:
.. code-block::
import megengine as mge
......@@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool):
@property
def _compute_mode(mod):
r"""Get or set the precision of intermediate results. The default option is "default",
which means that no special requirements will be placed on. When set to 'float32', it
would be used for accumulator and intermediate result, but only effective when input and
r"""Get or set the precision of intermediate results for conv, matmul. The default
option is None and will fallback to "default". When set to "float32", it will
trigger mixed precision computation on TensorCore, but only effective when input and
output are of float16 dtype.
Examples:
.. code-block::
import megengine as mge
mge.config._compute_mode = "default"
mge.config._compute_mode = "float32"
"""
return __compute_mode
......@@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str):
@property
def _conv_format(mod):
r"""Get or set convolution data/filter/output layout format. The default option is "default",
r"""Get or set convolution data/filter/output layout format. The default option is None,
which means that no special format will be enforced. All supported layout definitions:
``NCHW`` layout: ``{N, C, H, W}``
......@@ -145,6 +153,7 @@ def _conv_format(mod):
``NCHW64`` layout: ``{N, C/64, H, W, 64}``
Examples:
.. code-block::
import megengine as mge
......@@ -159,12 +168,35 @@ def _conv_format(mod, format: str):
__conv_format = format
@property
def _bn_format(mod):
r"""Get or set batchnorm param layout format. The default option is None and will
fallback to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
param format of batchnorm will be changed to NHWC.
Examples:
.. code-block::
import megengine as mge
mge.config._bn_format = "dim_111c"
"""
return __bn_format
@_bn_format.setter
def _bn_format(mod, format: str):
global __bn_format
__bn_format = format
@property
def _auto_format_convert(mod):
r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
The default value is False, which means no conversion is performed.
Examples:
.. code-block::
import megengine as mge
......@@ -184,15 +216,17 @@ def _reset_execution_config(
async_level=None,
compute_mode=None,
conv_format=None,
bn_format=None,
auto_format_convert=None,
):
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format
orig_flags = (
_benchmark_kernel,
_deterministic_kernel,
get_option("async_level"),
__compute_mode,
__conv_format,
__bn_format,
get_auto_format_convert(),
)
if benchmark_kernel is not None:
......@@ -205,6 +239,8 @@ def _reset_execution_config(
__compute_mode = compute_mode
if conv_format is not None:
__conv_format = conv_format
if bn_format is not None:
__bn_format = bn_format
if auto_format_convert is not None:
set_auto_format_convert(auto_format_convert)
......@@ -218,12 +254,14 @@ def _override(
async_level=None,
compute_mode=None,
conv_format=None,
bn_format=None,
auto_format_convert=None,
):
r"""A context manager that users can opt in by attaching the decorator to set
the config of the global variable.
Examples:
.. code-block::
import megengine as mge
......@@ -234,6 +272,7 @@ def _override(
async_level=2,
compute_mode="float32",
conv_format="NHWC",
bn_format="dim_111c",
auto_format_convert=True,
)
def train():
......@@ -244,6 +283,7 @@ def _override(
async_level,
compute_mode,
conv_format,
bn_format,
auto_format_convert,
)
try:
......@@ -254,4 +294,4 @@ def _override(
def _get_actual_op_param(function_param, config_param):
return function_param if config_param == "default" else config_param
return function_param if config_param is "default" else config_param
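The resolution rule: the per-call argument is used only while the global config is still "default"; once the config is set, it wins. A couple of illustrative calls:

_get_actual_op_param("float32", "default")  # -> "float32" (config untouched, keep the call-site value)
_get_actual_op_param("default", "float32")  # -> "float32" (global config overrides)
_get_actual_op_param("NCHW", "NHWC")        # -> "NHWC"   (global conv_format overrides)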
......@@ -10,13 +10,19 @@ from .._imperative_rt.core2 import (
_enabled = False
_set_amp_dtype_autocast(_enabled)
__all__ = [
"enabled",
"high_prec_dtype",
"low_prec_dtype",
]
@property
def enabled(mod):
r"""Get or set amp autocast mode enabled or not.
Examples:
.. code-block::
import megengine as mge
......@@ -36,9 +42,9 @@ def enabled(mod, enabled: bool):
def high_prec_dtype(mod):
r"""Get or set amp autocast mode's higher precision dtype. It will change the
target dtype in tensor casting for better precision. Default: float32.
Examples:
.. code-block::
import megengine as mge
......@@ -56,9 +62,9 @@ def high_prec_dtype(mod, dtype: str):
def low_prec_dtype(mod):
r"""Get or set amp autocast mode's lower precision dtype. It will change the
target dtype in tensor casting for better speed and lower memory usage. Default: float16.
Examples:
.. code-block::
import megengine as mge
......
......@@ -63,6 +63,7 @@ def _matmul(
assert dim1 > 0 and dim2 > 0
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
if dim1 == 1 and dim2 == 1: # dispatch to Dot
(result,) = apply(builtin.Dot(), inp1, inp2)
return result
......
......@@ -441,7 +441,6 @@ def deformable_conv2d(
or conv_mode.name == "CROSS_CORRELATION"
)
if amp._enabled:
compute_mode = "float32"
inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
else:
offset = offset.astype("float32")
......@@ -1182,7 +1181,6 @@ def batch_norm(
momentum: float = 0.9,
eps: float = 1e-5,
inplace: bool = True,
compute_mode="default",
param_dim="dim_1c11"
):
r"""Applies batch normalization to the input.
......
......@@ -19,7 +19,6 @@ class _BatchNorm(Module):
affine=True,
track_running_stats=True,
freeze=False,
compute_mode="default",
param_dim="dim_1c11",
**kwargs
):
......@@ -31,7 +30,6 @@ class _BatchNorm(Module):
self.track_running_stats = track_running_stats
self._track_running_stats_saved = track_running_stats
self.freeze = freeze
self.compute_mode = compute_mode
self.param_dim = param_dim
if self.freeze:
assert (
......@@ -106,7 +104,6 @@ class _BatchNorm(Module):
or ((self.running_mean is None) and (self.running_var is None)),
momentum=exponential_average_factor,
eps=self.eps,
compute_mode=self.compute_mode,
param_dim=self.param_dim,
)
......
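After the removal of `compute_mode`, a batch_norm call only carries `param_dim` for layout selection; precision is now governed by amp and `mge.config._compute_mode` rather than a per-call argument. A small sketch under assumed shapes:

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.random.randn(2, 4, 8, 8), dtype="float32")
weight = mge.tensor(np.ones((1, 4, 1, 1), dtype=np.float32))
bias = mge.tensor(np.zeros((1, 4, 1, 1), dtype=np.float32))

# no compute_mode argument anymore; param_dim="dim_1c11" is the NCHW param layout
out = F.batch_norm(
    inp, weight=weight, bias=bias, training=True, inplace=False, param_dim="dim_1c11"
)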
......@@ -8,7 +8,13 @@ from typing import Union
import numpy as np
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core._imperative_rt.core2 import (
get_auto_format_convert,
pop_scope,
push_scope,
set_auto_format_convert,
set_option,
)
from ..core.tensor.utils import set_convert_inputs
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
......@@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is "
+ str(type(param))
)
param._reset(Tensor(param.numpy(), no_cache=True))
param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
for name, default in self._defaults.items():
if default is required and name not in param_group:
......@@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta):
# set the global state `_enable_convert_inputs` to `False` to disable
# the `convert_inputs` for param updates
set_option("record_computing_path", 0)
_origin_auto_format = get_auto_format_convert()
set_auto_format_convert(False)
if self._disable_type_convert:
backup = set_convert_inputs(False)
for group in self.param_groups:
......@@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta):
# restore the global state `_enable_convert_inputs`
set_convert_inputs(backup)
set_option("record_computing_path", 1)
set_auto_format_convert(_origin_auto_format)
return self
@deprecated(version="1.0", reason="use clear_grad instead")
......
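The `format=param.format` addition keeps an NHWC parameter NHWC when the optimizer rebuilds its storage. A quick sketch of what it preserves, reusing the private `_reset` call exactly as it appears in the diff (purely for illustration):

import numpy as np
from megengine import Parameter, Tensor, amp

p = Parameter(np.ones((4, 3, 1, 1), dtype=np.float32))
amp.convert_tensor_format(p)  # in-place: p.format is now "nhwc"
p._reset(Tensor(p.numpy(), no_cache=True, format=p.format))
assert p.format == "nhwc"     # the reset no longer drops the format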
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
import megengine.functional as F
import megengine.module as M
from megengine import Parameter, Tensor, amp, tensor
class MyModule(M.Module):
class InnerModule(M.Module):
def __init__(self):
super().__init__()
self.bn = M.BatchNorm2d(4)
def forward(self, x):
return self.bn(x)
def __init__(self):
super().__init__()
self.i = self.InnerModule()
self.conv = M.Conv2d(4, 4, 4, groups=2)
self.bn = M.BatchNorm2d(4)
self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
def forward(self, x):
x = self.i(x)
x = self.bn(x)
return x
@pytest.mark.parametrize("is_inplace", [False, True])
def test_convert_module(is_inplace):
m = MyModule()
m = amp.convert_module_format(m, is_inplace)
for name, param in m.named_tensors():
assert param.format == "nhwc"
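Putting the pieces of this commit together, a single NHWC training step might look like the sketch below; the toy module, shapes, and hyper-parameters are assumptions for illustration, not part of the diff, and the tests that follow exercise which ops actually run in NHWC at this stage.

import numpy as np
import megengine as mge
import megengine.module as M
import megengine.optimizer as optim
from megengine import amp
from megengine.autodiff import GradManager

net = amp.convert_module_format(M.Conv2d(4, 4, 3, padding=1))
gm = GradManager().attach(net.parameters())
opt = optim.SGD(net.parameters(), lr=0.01)

data = amp.convert_tensor_format(
    mge.tensor(np.random.randn(2, 4, 8, 8), dtype="float32")
)

with mge.config._override(
    auto_format_convert=True, conv_format="NHWC", bn_format="dim_111c"
):
    with gm:
        loss = net(data).mean()
        gm.backward(loss)
    opt.step().clear_grad()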
......@@ -8,14 +8,27 @@ from megengine.autodiff import GradManager
def test_basic():
a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc")
data = np.arange(0, 24).reshape((1, 2, 3, 4))
# init from numpy
a = tensor(data, format="nhwc")
assert a.format == "nhwc"
# init from tensor
b = tensor(a)
assert b.format == "nhwc"
# TODO: fix Tensor init bug for another Tensor
# TODO: init from tensor with new format
# c = tensor(a, format="nchw")
# assert c.format == "nchw"
# TODO: reset from numpy
# b[...] = data
# assert b.format == "nhwc"
# reset from tensor
b[...] = tensor(data, format="nchw")
assert b.format == "nchw"
def _compare_nchw_nhwc(data, func):
x1 = tensor(data, format="nchw")
......@@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func):
out1 = func(x1)
with mge.config._override(auto_format_convert=True):
out2 = func(x2)
np.testing.assert_equal(out1, out2)
np.testing.assert_almost_equal(out1, out2, decimal=5)
def test_dimshuffle():
......@@ -296,8 +309,10 @@ def test_backward():
with gm:
with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
x = F.conv2d(x, w, b)
# TODO: fix manual conversion to NHWC, commonly used in detection heads
# x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
gm.backward(x)
# TODO: backward grad has no format yet
# backward grad has no format
np.testing.assert_equal(
w.grad.numpy(),
np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
......
......@@ -921,12 +921,7 @@ def test_batchnorm2d_autocast():
amp.enabled = False
expected = F.batch_norm(
inp.astype("float16"),
weight=weight,
bias=bias,
training=True,
inplace=False,
compute_mode="float32",
inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
)
assert out.dtype == np.float16
assert expected.dtype == np.float16
......