From e393d1cf65c7ea326414f757f3bf84c25b1a4e1b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 25 Jan 2022 18:00:41 +0800
Subject: [PATCH] feat(mge/amp): add convert_format module for NHWC training

GitOrigin-RevId: 1b41e1042c0107d2b63f7753d20121fa04b17bd2
---
 imperative/python/megengine/amp/__init__.py    |  1 +
 imperative/python/megengine/amp/autocast.py    | 18 ++++++-
 .../python/megengine/amp/convert_format.py     | 45 ++++++++++++++++
 .../python/megengine/autodiff/grad_manager.py  | 11 +++-
 imperative/python/megengine/core/_config.py    | 54 ++++++++++++++++---
 .../python/megengine/core/tensor/amp.py        | 18 ++++---
 .../megengine/core/tensor/array_method.py      |  1 +
 imperative/python/megengine/functional/nn.py   |  2 -
 .../python/megengine/module/batchnorm.py       |  3 --
 .../python/megengine/optimizer/optimizer.py    | 13 ++++-
 .../test/unit/amp/test_convert_format.py       | 44 +++++++++++++++
 .../test/unit/core/test_formatted_tensor.py    | 23 ++++++--
 .../test/unit/functional/test_functional.py    |  7 +--
 13 files changed, 207 insertions(+), 33 deletions(-)
 create mode 100644 imperative/python/megengine/amp/convert_format.py
 create mode 100644 imperative/python/test/unit/amp/test_convert_format.py

diff --git a/imperative/python/megengine/amp/__init__.py b/imperative/python/megengine/amp/__init__.py
index 81344eaf1..57eb3a463 100644
--- a/imperative/python/megengine/amp/__init__.py
+++ b/imperative/python/megengine/amp/__init__.py
@@ -2,6 +2,7 @@ import mprop
 
 from ..core.tensor.amp import *
 from .autocast import autocast
+from .convert_format import convert_module_format, convert_tensor_format
 from .grad_scaler import GradScaler
 
 mprop.init()
diff --git a/imperative/python/megengine/amp/autocast.py b/imperative/python/megengine/amp/autocast.py
index 2c068c373..f13636fb6 100644
--- a/imperative/python/megengine/amp/autocast.py
+++ b/imperative/python/megengine/amp/autocast.py
@@ -1,5 +1,6 @@
 import functools
 
+from ..core import _config
 from ..core.tensor import amp
 
 
@@ -50,24 +51,37 @@ class autocast:
         self._origin_high = None
         self._origin_low = None
 
+        self._origin_configs = None
+
     def __enter__(self):
         self._origin_enabled = amp._enabled
-        self._origin_high = amp._get_amp_high_prec_dtype()
-        self._origin_low = amp._get_amp_low_prec_dtype()
         amp._enabled = self.enabled
         amp._set_amp_dtype_autocast(self.enabled)
+        if not self.enabled:
+            return
+
+        self._origin_high = amp._get_amp_high_prec_dtype()
+        self._origin_low = amp._get_amp_low_prec_dtype()
         amp._set_amp_high_prec_dtype(self.high_prec_dtype)
         amp._set_amp_low_prec_dtype(self.low_prec_dtype)
 
+        self._origin_configs = _config._reset_execution_config(compute_mode="float32")
+
     def __exit__(self, *args):
         amp._enabled = self._origin_enabled
         amp._set_amp_dtype_autocast(self._origin_enabled)
+        if not self.enabled:
+            return
         amp._set_amp_high_prec_dtype(self._origin_high)
         amp._set_amp_low_prec_dtype(self._origin_low)
 
+        _config._reset_execution_config(*self._origin_configs)
+
     def __call__(self, func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            if not self.enabled:
+                return func(*args, **kwargs)
             with self:
                 return func(*args, **kwargs)
diff --git a/imperative/python/megengine/amp/convert_format.py b/imperative/python/megengine/amp/convert_format.py
new file mode 100644
index 000000000..30a32baa1
--- /dev/null
+++ b/imperative/python/megengine/amp/convert_format.py
@@ -0,0 +1,45 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from copy import deepcopy
+
+from .. import functional as F
+from ..module import Module
+from ..tensor import Tensor
+
+
+def _is_nchw_format(param: Tensor):
+    # TODO: use better condition
+    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
+
+
+def convert_tensor_format(x: Tensor, inplace: bool = True):
+    """Convert NCHW Tensor to NHWC Tensor."""
+    if x.ndim == 4:
+        pattern = (0, 2, 3, 1)
+    elif x.ndim == 5:
+        pattern = (0, 1, 3, 4, 2)
+    else:
+        raise ValueError("Unsupported tensor ndim {}".format(x.ndim))
+    # TODO: use initialization from tensor after fixing format setting
+    if inplace:
+        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    else:
+        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    return x
+
+
+def convert_module_format(module: Module, inplace: bool = True):
+    """Convert NCHW Module to NHWC Module."""
+    if not inplace:
+        module = deepcopy(module)
+
+    for name, param in module.named_tensors():
+        if _is_nchw_format(param):
+            # hostvalue should still be valid, so no d2h cost.
+            convert_tensor_format(param, inplace=True)
+    return module
diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 881969f8f..8f4923009 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -1,7 +1,13 @@
 import weakref
 from typing import Callable, Iterable, List, Union
 
-from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
+from ..core._imperative_rt.core2 import (
+    get_auto_format_convert,
+    pop_scope,
+    push_scope,
+    set_auto_format_convert,
+    set_option,
+)
 from ..core.autodiff.grad import Grad
 from ..core.tensor.dtype import is_differentible_dtype
 from ..logger import get_logger
@@ -253,6 +259,8 @@ class GradManager:
         """
         push_scope("backward")
         set_option("record_computing_path", 0)
+        _origin_auto_format = get_auto_format_convert()
+        set_auto_format_convert(False)
         from ..functional import ones_like
 
         global backwarding_grad_manager
@@ -296,6 +304,7 @@ class GradManager:
             self.release()
             backwarding_grad_manager = cache
             set_option("record_computing_path", 1)
+            set_auto_format_convert(_origin_auto_format)
             pop_scope("backward")
 
     def record(self):
diff --git a/imperative/python/megengine/core/_config.py b/imperative/python/megengine/core/_config.py
index 49877f2a0..f09dff5f9 100644
--- a/imperative/python/megengine/core/_config.py
+++ b/imperative/python/megengine/core/_config.py
@@ -10,8 +10,10 @@ from ._imperative_rt.core2 import (
     set_option,
 )
 
+# use "default" to distinguish it from None in _reset_execution_config
 __compute_mode = "default"
 __conv_format = "default"
+__bn_format = "default"
 _benchmark_kernel = False
 _deterministic_kernel = False
 
@@ -22,6 +24,8 @@ __all__ = [
     "disable_memory_forwarding",
     "_compute_mode",
     "_conv_format",
+    "_bn_format",
+    "_auto_format_convert",
     "_override",
 ]
 
@@ -32,6 +36,7 @@ def benchmark_kernel(mod):
     which means use heuristic to choose the fastest algorithm.
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
@@ -55,6 +60,7 @@ def deterministic_kernel(mod):
     which means the algorithm is not reproducible.
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
@@ -75,6 +81,7 @@ def async_level(mod) -> int:
     which means both device and user side errors are async.
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
@@ -110,16 +117,17 @@ def disable_memory_forwarding(mod, disable: bool):
 
 @property
 def _compute_mode(mod):
-    r"""Get or set the precision of intermediate results. The default option is "default",
-    which means that no special requirements will be placed on. When set to 'float32', it
-    would be used for accumulator and intermediate result, but only effective when input and
+    r"""Get or set the precision of intermediate results for conv, matmul. The default
+    option is None and will fall back to "default". When set to "float32", it will
+    trigger mixed precision computation on TensorCore, but only effective when input and
     output are of float16 dtype.
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
-        mge.config._compute_mode = "default"
+        mge.config._compute_mode = "float32"
     """
     return __compute_mode
 
@@ -132,7 +140,7 @@ def _compute_mode(mod, _compute_mode: str):
 
 @property
 def _conv_format(mod):
-    r"""Get or set convolution data/filter/output layout format. The default option is "default",
+    r"""Get or set convolution data/filter/output layout format. The default option is None,
     which means that no special format will be placed on. There are all layout definitions
 
     ``NCHW`` layout: ``{N, C, H, W}``
@@ -145,6 +153,7 @@ def _conv_format(mod):
     ``NCHW64`` layout: ``{N, C/64, H, W, 64}``
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
@@ -159,12 +168,35 @@ def _conv_format(mod, format: str):
     __conv_format = format
 
 
+@property
+def _bn_format(mod):
+    r"""Get or set batchnorm param layout format. The default option is None and will
+    fall back to "dim_1c11" which corresponds to NCHW format. When set to "dim_111c",
+    param format of batchnorm will be changed to NHWC.
+
+    Examples:
+
+    .. code-block::
+
+        import megengine as mge
+        mge.config._bn_format = "dim_111c"
+    """
+    return __bn_format
+
+
+@_bn_format.setter
+def _bn_format(mod, format: str):
+    global __bn_format
+    __bn_format = format
+
+
 @property
 def _auto_format_convert(mod):
     r"""Automatically convert indexing params' order for NCHW Tensor to NHWC order.
     The default value is False, which means no convert.
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
@@ -184,15 +216,17 @@ def _reset_execution_config(
     async_level=None,
     compute_mode=None,
     conv_format=None,
+    bn_format=None,
     auto_format_convert=None,
 ):
-    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
+    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format, __bn_format
    orig_flags = (
         _benchmark_kernel,
         _deterministic_kernel,
         get_option("async_level"),
         __compute_mode,
         __conv_format,
+        __bn_format,
         get_auto_format_convert(),
     )
     if benchmark_kernel is not None:
@@ -205,6 +239,8 @@ def _reset_execution_config(
         __compute_mode = compute_mode
     if conv_format is not None:
         __conv_format = conv_format
+    if bn_format is not None:
+        __bn_format = bn_format
     if auto_format_convert is not None:
         set_auto_format_convert(auto_format_convert)
 
@@ -218,12 +254,14 @@ def _override(
     async_level=None,
     compute_mode=None,
     conv_format=None,
+    bn_format=None,
     auto_format_convert=None,
 ):
     r"""A context manager that users can opt in by attaching the decorator to set
     the config of the global variable.
 
     Examples:
+
     .. code-block::
 
         import megengine as mge
 
         @mge.config._override(
             benchmark_kernel=True,
             async_level=2,
             compute_mode="float32",
             conv_format="NHWC",
+            bn_format="dim_111c",
             auto_format_convert=True,
         )
         def train():
@@ -244,6 +283,7 @@ def _override(
         async_level,
         compute_mode,
         conv_format,
+        bn_format,
         auto_format_convert,
     )
     try:
         yield
@@ -254,4 +294,4 @@
 
 
 def _get_actual_op_param(function_param, config_param):
-    return function_param if config_param is "default" else config_param
+    return function_param if config_param == "default" else config_param
diff --git a/imperative/python/megengine/core/tensor/amp.py b/imperative/python/megengine/core/tensor/amp.py
index d38240f20..44b9c0b23 100644
--- a/imperative/python/megengine/core/tensor/amp.py
+++ b/imperative/python/megengine/core/tensor/amp.py
@@ -10,13 +10,19 @@ from .._imperative_rt.core2 import (
 _enabled = False
 _set_amp_dtype_autocast(_enabled)
 
+__all__ = [
+    "enabled",
+    "high_prec_dtype",
+    "low_prec_dtype",
+]
+
 
 @property
 def enabled(mod):
     r"""Get or set amp autocast mode enabled or not.
-    
+
     Examples:
-    
+
     .. code-block::
 
         import megengine as mge
@@ -36,9 +42,9 @@ def high_prec_dtype(mod):
     r"""Get or set amp autocast mode's higher precision dtype. It will change the
     target dtype in tensor casting for better precision. Default: float32.
-    
+
     Examples:
-    
+
     .. code-block::
 
         import megengine as mge
@@ -56,9 +62,9 @@ def low_prec_dtype(mod):
     r"""Get or set amp autocast mode's lower precision dtype. It will change the
     target dtype in tensor casting for better speed and memory. Default: float16.
-    
+
     Examples:
-    
+
     .. code-block::
 
         import megengine as mge
diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index 102af7f2c..ce5151a33 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -63,6 +63,7 @@ def _matmul(
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
+
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 52900b21c..57718b2ce 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -441,7 +441,6 @@ def deformable_conv2d(
         or conv_mode.name == "CROSS_CORRELATION"
     )
     if amp._enabled:
-        compute_mode = "float32"
         inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
     else:
         offset = offset.astype("float32")
@@ -1182,7 +1181,6 @@ def batch_norm(
     momentum: float = 0.9,
     eps: float = 1e-5,
     inplace: bool = True,
-    compute_mode="default",
     param_dim="dim_1c11"
 ):
     r"""Applies batch normalization to the input.
diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py
index 76bc427ea..22162ff7f 100644
--- a/imperative/python/megengine/module/batchnorm.py
+++ b/imperative/python/megengine/module/batchnorm.py
@@ -19,7 +19,6 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        compute_mode="default",
         param_dim="dim_1c11",
         **kwargs
     ):
@@ -31,7 +30,6 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
-        self.compute_mode = compute_mode
         self.param_dim = param_dim
         if self.freeze:
             assert (
@@ -106,7 +104,6 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
-            compute_mode=self.compute_mode,
             param_dim=self.param_dim,
         )
 
diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index 2bed6bc78..b4e624a1a 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -8,7 +8,13 @@ from typing import Union
 
 import numpy as np
 
-from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
+from ..core._imperative_rt.core2 import (
+    get_auto_format_convert,
+    pop_scope,
+    push_scope,
+    set_auto_format_convert,
+    set_option,
+)
 from ..core.tensor.utils import set_convert_inputs
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
@@ -90,7 +96,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True))
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -139,6 +145,8 @@ class Optimizer(metaclass=ABCMeta):
         # set the globle state `_enable_convert_inputs` to `False` to disable
         # the `convert_inputs` for param updates
         set_option("record_computing_path", 0)
+        _origin_auto_format = get_auto_format_convert()
+        set_auto_format_convert(False)
         if self._disable_type_convert:
             backup = set_convert_inputs(False)
         for group in self.param_groups:
@@ -155,6 +163,7 @@ class Optimizer(metaclass=ABCMeta):
             # restore the globle state `_enable_convert_inputs`
             set_convert_inputs(backup)
         set_option("record_computing_path", 1)
+        set_auto_format_convert(_origin_auto_format)
         return self
 
     @deprecated(version="1.0", reason="use clear_grad instead")
diff --git a/imperative/python/test/unit/amp/test_convert_format.py b/imperative/python/test/unit/amp/test_convert_format.py
new file mode 100644
index 000000000..a2b199be3
--- /dev/null
+++ b/imperative/python/test/unit/amp/test_convert_format.py
@@ -0,0 +1,44 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine.functional as F
+import megengine.module as M
+from megengine import Parameter, Tensor, amp, tensor
+
+
+class MyModule(M.Module):
+    class InnerModule(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = M.BatchNorm2d(4)
+
+        def forward(self, x):
+            return self.bn(x)
+
+    def __init__(self):
+        super().__init__()
+        self.i = self.InnerModule()
+        self.conv = M.Conv2d(4, 4, 4, groups=2)
+        self.bn = M.BatchNorm2d(4)
+        self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
+        self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
+
+    def forward(self, x):
+        x = self.i(x)
+        x = self.bn(x)
+        return x
+
+
+@pytest.mark.parametrize("is_inplace", [False, True])
+def test_convert_module(is_inplace):
+    m = MyModule()
+    m = amp.convert_module_format(m, is_inplace)
+    for name, param in m.named_tensors():
+        assert param.format == "nhwc"
diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py
index 1f87d2f68..6c71fe0a9 100644
--- a/imperative/python/test/unit/core/test_formatted_tensor.py
+++ b/imperative/python/test/unit/core/test_formatted_tensor.py
@@ -8,14 +8,27 @@ from megengine.autodiff import GradManager
 
 
 def test_basic():
-    a = tensor(np.arange(0, 24).reshape((1, 2, 3, 4)), dtype="float32", format="nhwc")
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    # init from numpy
+    a = tensor(data, format="nhwc")
     assert a.format == "nhwc"
+
+    # init from tensor
     b = tensor(a)
     assert b.format == "nhwc"
-    # TODO: fix Tensor init bug for another Tensor
+
+    # TODO: init from tensor with new format
     # c = tensor(a, format="nchw")
     # assert c.format == "nchw"
 
+    # TODO: reset from numpy
+    # b[...] = data
+    # assert b.format == "nhwc"
+
+    # reset from tensor
+    b[...] = tensor(data, format="nchw")
+    assert b.format == "nchw"
+
 
 def _compare_nchw_nhwc(data, func):
     x1 = tensor(data, format="nchw")
@@ -23,7 +36,7 @@ def _compare_nchw_nhwc(data, func):
     out1 = func(x1)
     with mge.config._override(auto_format_convert=True):
         out2 = func(x2)
-    np.testing.assert_equal(out1, out2)
+    np.testing.assert_almost_equal(out1, out2, decimal=5)
 
 
 def test_dimshuffle():
@@ -296,8 +309,10 @@ def test_backward():
     with gm:
         with mge.config._override(auto_format_convert=True, conv_format="NHWC"):
             x = F.conv2d(x, w, b)
+            # TODO: fix manually convert to NHWC, usually used in detection head
+            # x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
         gm.backward(x)
-    # TODO: backward grad has no format yet
+    # backward grad has no format
     np.testing.assert_equal(
         w.grad.numpy(),
         np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 44a186c2e..c91fe194a 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -921,12 +921,7 @@ def test_batchnorm2d_autocast():
         amp.enabled = False
 
     expected = F.batch_norm(
-        inp.astype("float16"),
-        weight=weight,
-        bias=bias,
-        training=True,
-        inplace=False,
-        compute_mode="float32",
+        inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
     )
     assert out.dtype == np.float16
     assert expected.dtype == np.float16
-- 
GitLab
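
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the two helpers added by this patch, amp.convert_module_format and amp.convert_tensor_format, would fit into a plain training step. The model, optimizer, and input shapes are illustrative assumptions, and whether the NHWC path runs end-to-end still depends on backend support for NHWC kernels.

    # illustrative sketch, assuming MegEngine with this patch applied
    import numpy as np
    import megengine as mge
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine import amp
    from megengine.autodiff import GradManager

    # hypothetical NCHW model; convert_module_format re-creates its 4D/5D
    # params and buffers with format="nhwc"
    model = M.Sequential(M.Conv2d(3, 16, 3), M.BatchNorm2d(16))
    model = amp.convert_module_format(model, inplace=False)

    gm = GradManager().attach(model.parameters())
    opt = optim.SGD(model.parameters(), lr=0.01)

    # convert the input the same way before feeding it to the NHWC model
    inp = mge.tensor(np.random.randn(2, 3, 32, 32), dtype="float32")
    inp = amp.convert_tensor_format(inp, inplace=False)

    with gm:
        loss = model(inp).mean()
        gm.backward(loss)   # backward() disables auto format convert internally
    opt.step().clear_grad()  # step() also restores the auto-format setting

The inplace=False calls return converted copies; with the default inplace=True the module's existing tensors are reset in place, which the patch notes should avoid an extra device-to-host copy.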