From 261a5bce23b066b7c05c188ff638fa2c31b066f5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 26 May 2022 18:00:29 +0800 Subject: [PATCH] feat(imperative/amp): add dimshuffle in set_format for nhwc GitOrigin-RevId: 5ced9e1a31d78ea0628663ba9e3aa4942255713f --- .../python/megengine/amp/convert_format.py | 14 ++++---------- imperative/python/megengine/core/_config.py | 7 ------- imperative/python/megengine/functional/nn.py | 2 +- .../python/megengine/functional/tensor.py | 1 - .../python/megengine/optimizer/optimizer.py | 18 ++++++++---------- .../test/unit/amp/test_convert_format.py | 8 +++++--- imperative/src/impl/transformations/format.cpp | 5 ++++- 7 files changed, 22 insertions(+), 33 deletions(-) diff --git a/imperative/python/megengine/amp/convert_format.py b/imperative/python/megengine/amp/convert_format.py index 657ee1f99..28af3640a 100644 --- a/imperative/python/megengine/amp/convert_format.py +++ b/imperative/python/megengine/amp/convert_format.py @@ -23,23 +23,17 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): if not _is_nchw_format(x): return x - if x.ndim == 4: - pattern = (0, 2, 3, 1) - elif x.ndim == 5: - pattern = (0, 1, 3, 4, 2) - else: + if x.ndim != 4 and x.ndim != 5: raise ValueError("Unsupport tensor ndim {}".format(x.ndim)) - # TODO: use initialization from tensor after fixing format setting if x.format != "nhwc": + # hostvalue should still be valid, so no d2h cost. + data = x.numpy() if inplace: - # hostvalue should still be valid, so no d2h cost. - data = x.numpy() # reset will destroy existed backward grad x[...] = Tensor(data, format="nhwc") else: # use mge interface to maintain grad - x = F.transpose(x, pattern) - x.format = "nhwc" + x = Tensor(data, format="nhwc") return x diff --git a/imperative/python/megengine/core/_config.py b/imperative/python/megengine/core/_config.py index dd1f4b4bc..e756b572b 100644 --- a/imperative/python/megengine/core/_config.py +++ b/imperative/python/megengine/core/_config.py @@ -181,7 +181,6 @@ def _reset_execution_config( deterministic_kernel=None, async_level=None, compute_mode=None, - auto_format_convert=None, ): global _benchmark_kernel, _deterministic_kernel, __compute_mode orig_flags = ( @@ -189,7 +188,6 @@ def _reset_execution_config( _deterministic_kernel, get_option("async_level"), __compute_mode, - get_auto_format_convert(), ) if benchmark_kernel is not None: _benchmark_kernel = benchmark_kernel @@ -199,8 +197,6 @@ def _reset_execution_config( set_option("async_level", async_level) if compute_mode is not None: __compute_mode = compute_mode - if auto_format_convert is not None: - set_auto_format_convert(auto_format_convert) return orig_flags @@ -211,7 +207,6 @@ def _override( deterministic_kernel=None, async_level=None, compute_mode=None, - auto_format_convert=None, ): r"""A context manager that users can opt in by attaching the decorator to set the config of the global variable. @@ -227,7 +222,6 @@ def _override( deterministic_kernel = Fasle, async_level=2, compute_mode="float32", - auto_format_convert=True, ) def train(): """ @@ -236,7 +230,6 @@ def _override( deterministic_kernel=deterministic_kernel, async_level=async_level, compute_mode=compute_mode, - auto_format_convert=auto_format_convert, ) try: yield diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 9650d6b0e..6f7cc79a0 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1206,9 +1206,9 @@ def batch_norm( if x is None: x = Const(value, inp.dtype, inp.device) - x.format = inp.format shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) (result,) = apply(builtin.Broadcast(), x, shape) + result.format = inp.format return result else: assert x_ndim == 1 diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index b3aca0484..7ad187b00 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -274,7 +274,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: return x # set x's format to use FormatTransformation rule for Broadcast. - x.format = inp.format return broadcast_to(x, inp.shape) diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index ad783f832..412579da7 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -91,14 +91,13 @@ class Optimizer(metaclass=ABCMeta): else: param_group["params"] = list(param_group["params"]) - with _config._override(auto_format_convert=False): - for param in param_group["params"]: - if not isinstance(param, Parameter): - raise TypeError( - "optimizer can only optimize Parameters, but one of the params is " - + str(type(param)) - ) - param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) + for param in param_group["params"]: + if not isinstance(param, Parameter): + raise TypeError( + "optimizer can only optimize Parameters, but one of the params is " + + str(type(param)) + ) + param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) for name, default in self._defaults.items(): if default is required and name not in param_group: @@ -121,8 +120,7 @@ class Optimizer(metaclass=ABCMeta): def _add_state(self, param, state_name, initializer=None): if initializer is None: - with _config._override(auto_format_convert=False): - initializer = np.zeros(param.shape, dtype=np.float32) + initializer = np.zeros(param.shape, dtype=np.float32) state_dict = self._state.setdefault(param, {}) assert state_name not in state_dict state = Tensor(initializer, no_cache=True, format=param.format) diff --git a/imperative/python/test/unit/amp/test_convert_format.py b/imperative/python/test/unit/amp/test_convert_format.py index bfed1a16f..2724b829b 100644 --- a/imperative/python/test/unit/amp/test_convert_format.py +++ b/imperative/python/test/unit/amp/test_convert_format.py @@ -10,7 +10,8 @@ import pytest import megengine.functional as F import megengine.module as M -from megengine import Parameter, Tensor, amp, config +from megengine import Parameter, Tensor, amp +from megengine.core._config import set_auto_format_convert class MyModule(M.Module): @@ -56,5 +57,6 @@ def test_convert_module(is_inplace): m = amp.convert_module_format(m, is_inplace) for name, param in m.named_tensors(): assert param.format == "nhwc" - with config._override(auto_format_convert=False): - assert param.shape == expected_shape[name], name + set_auto_format_convert(False) + assert param.shape == expected_shape[name], name + set_auto_format_convert(True) diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index 179779680..a80d7fc52 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -19,6 +19,9 @@ TypedValueRef FormatTransformation::to( const std::string& scope) const { std::vector pattern; Format format = tensor.format(); + if (format == target) + return as(tensor, target); + if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) { // FIXME(czh): temporary fast path for group conv 5D weight. if (tensor.value().shape().cast().ndim == 5) { @@ -618,7 +621,7 @@ ValueRefList FormatTransformation::apply_transformation( } else if (auto* _op = op.as()) { auto&& inp_ref = inputs[0].as_ref(m_value_type); mgb_assert(inp_ref, "Cannot set format for non-format Tensor."); - return {m_value_type.make(inp_ref->value(), _op->format())}; + return {to(*inp_ref, _op->format().type(), "")}; } else if (op.is()) { auto&& inp_ref = inputs[0].as_ref(m_value_type); if (inp_ref) { -- GitLab