From fc633ce4ffb935a32f7429dc8f26175b94f4a86c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 21 Mar 2022 19:55:16 +0800 Subject: [PATCH] fix(imperative/amp): fix custom grad in Subgraph GitOrigin-RevId: 1c728d6ab97e8a49f84bf7e309a288938111d7be --- imperative/python/megengine/amp/autocast.py | 38 ++--- .../python/megengine/amp/convert_format.py | 11 +- .../python/megengine/amp/grad_scaler.py | 4 +- imperative/python/megengine/core/_config.py | 11 +- .../python/megengine/core/autodiff/grad.py | 2 + .../python/megengine/core/tensor/utils.py | 1 + imperative/python/megengine/functional/nn.py | 8 +- .../python/megengine/module/batchnorm.py | 3 - .../python/megengine/optimizer/optimizer.py | 8 +- imperative/python/megengine/optimizer/sgd.py | 1 + .../test/unit/amp/test_convert_format.py | 18 +- .../test/unit/core/test_formatted_tensor.py | 114 ++++++++++--- .../src/impl/transformations/format.cpp | 158 +++++++++++++----- .../imperative/transformations/format.h | 5 +- .../imperative/transformations/grad.h | 2 + 15 files changed, 280 insertions(+), 104 deletions(-) diff --git a/imperative/python/megengine/amp/autocast.py b/imperative/python/megengine/amp/autocast.py index 33e06a35e..e104e0df3 100644 --- a/imperative/python/megengine/amp/autocast.py +++ b/imperative/python/megengine/amp/autocast.py @@ -50,36 +50,36 @@ class autocast: self._origin_enabled = None self._origin_high = None self._origin_low = None + self._origin_compute_mode = None self._origin_configs = None def __enter__(self): - self._origin_enabled = amp._enabled - amp._enabled = self.enabled - amp._set_amp_dtype_autocast(self.enabled) - if not self.enabled: - return + if self.enabled: + self._origin_enabled = amp._enabled + self._origin_high = amp._get_amp_high_prec_dtype() + self._origin_low = amp._get_amp_low_prec_dtype() + amp._enabled = self.enabled + amp._set_amp_dtype_autocast(self.enabled) + amp._set_amp_high_prec_dtype(self.high_prec_dtype) + amp._set_amp_low_prec_dtype(self.low_prec_dtype) - self._origin_high = amp._get_amp_high_prec_dtype() - self._origin_low = amp._get_amp_low_prec_dtype() - amp._set_amp_high_prec_dtype(self.high_prec_dtype) - amp._set_amp_low_prec_dtype(self.low_prec_dtype) - - self._origin_configs = _config._reset_execution_config(compute_mode="float32") + self._origin_configs = _config._reset_execution_config( + compute_mode="float32" + ) def __exit__(self, *args): - amp._enabled = self._origin_enabled - amp._set_amp_dtype_autocast(self._origin_enabled) - if not self.enabled: - return - amp._set_amp_high_prec_dtype(self._origin_high) - amp._set_amp_low_prec_dtype(self._origin_low) + if self.enabled: + amp._enabled = self._origin_enabled + amp._set_amp_dtype_autocast(self._origin_enabled) + amp._set_amp_high_prec_dtype(self._origin_high) + amp._set_amp_low_prec_dtype(self._origin_low) + + _config._reset_execution_config(*self._origin_compute_mode) def __call__(self, func): @functools.wraps(func) def wrapper(*args, **kwargs): - if not self.enabled: - return func(*args, **kwargs) with self: return func(*args, **kwargs) diff --git a/imperative/python/megengine/amp/convert_format.py b/imperative/python/megengine/amp/convert_format.py index 30a32baa1..3eca860af 100644 --- a/imperative/python/megengine/amp/convert_format.py +++ b/imperative/python/megengine/amp/convert_format.py @@ -10,6 +10,7 @@ from copy import deepcopy from .. 
import functional as F from ..module import Module from ..tensor import Tensor +from ..core import _config def _is_nchw_format(param: Tensor): @@ -26,10 +27,12 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): else: raise ValueError("Unsupport tensor ndim {}".format(x.ndim)) # TODO: use initialization from tensor after fixing format setting - if inplace: - x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc") - else: - x = Tensor(x.numpy().transpose(*pattern), format="nhwc") + if x.format != "nhwc": + if inplace: + data = x.numpy().transpose(*pattern) + x[...] = Tensor(data, format="nhwc") + else: + x = Tensor(x.numpy().transpose(*pattern), format="nhwc") return x diff --git a/imperative/python/megengine/amp/grad_scaler.py b/imperative/python/megengine/amp/grad_scaler.py index 337c8f643..3853994a8 100644 --- a/imperative/python/megengine/amp/grad_scaler.py +++ b/imperative/python/megengine/amp/grad_scaler.py @@ -144,7 +144,9 @@ class GradScaler: def _check_gradients(self, grads, scale): if len(grads) == 0: return False - return _check_non_finite(grads, scale) + rst = _check_non_finite(grads, scale) + rst = rst.numpy() + return rst def update(self, new_scale: float = None): r"""Update the scale factor according to whether encountered overflow grad. diff --git a/imperative/python/megengine/core/_config.py b/imperative/python/megengine/core/_config.py index 61c93194a..a9e2102e5 100644 --- a/imperative/python/megengine/core/_config.py +++ b/imperative/python/megengine/core/_config.py @@ -182,7 +182,6 @@ def _reset_execution_config( deterministic_kernel=None, async_level=None, compute_mode=None, - bn_format=None, auto_format_convert=None, ): global _benchmark_kernel, _deterministic_kernel, __compute_mode @@ -234,11 +233,11 @@ def _override( def train(): """ orig_flags = _reset_execution_config( - benchmark_kernel, - deterministic_kernel, - async_level, - compute_mode, - auto_format_convert, + benchmark_kernel=benchmark_kernel, + deterministic_kernel=deterministic_kernel, + async_level=async_level, + compute_mode=compute_mode, + auto_format_convert=auto_format_convert, ) try: yield diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index f9e182ef9..799f6e510 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -64,7 +64,9 @@ class Grad: continue grad.suppress() + print("before backward") self._impl.backward(ys, dys) + print("after backward") for grad in group: if grad is self: diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index afa1f08fe..1626b610e 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -24,6 +24,7 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from .._imperative_rt.ops import jit_supported from .._wrap import as_device from ..autodiff.grad import Function +from .. 
import _config from ..ops import builtin from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype from .dtype import is_dtype_equal, is_quantize diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 0c72e3ba9..34def5e2d 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1226,12 +1226,16 @@ def batch_norm( bias = make_full_if_none(bias, 0) if not training: - op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps) + op = builtin.BatchNorm( + fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps + ) ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] return ret else: - op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps) + op = builtin.BatchNorm( + avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps + ) if has_mean or has_var: running_mean = make_full_if_none(running_mean, 0) running_var = make_full_if_none(running_var, 1) diff --git a/imperative/python/megengine/module/batchnorm.py b/imperative/python/megengine/module/batchnorm.py index 22162ff7f..72d1f3d73 100644 --- a/imperative/python/megengine/module/batchnorm.py +++ b/imperative/python/megengine/module/batchnorm.py @@ -19,7 +19,6 @@ class _BatchNorm(Module): affine=True, track_running_stats=True, freeze=False, - param_dim="dim_1c11", **kwargs ): super(_BatchNorm, self).__init__(**kwargs) @@ -30,7 +29,6 @@ class _BatchNorm(Module): self.track_running_stats = track_running_stats self._track_running_stats_saved = track_running_stats self.freeze = freeze - self.param_dim = param_dim if self.freeze: assert ( self._track_running_stats_saved @@ -104,7 +102,6 @@ class _BatchNorm(Module): or ((self.running_mean is None) and (self.running_var is None)), momentum=exponential_average_factor, eps=self.eps, - param_dim=self.param_dim, ) return output diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index b4e624a1a..8c8a11c6a 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -8,6 +8,7 @@ from typing import Union import numpy as np +from ..core import _config from ..core._imperative_rt.core2 import ( get_auto_format_convert, pop_scope, @@ -96,7 +97,7 @@ class Optimizer(metaclass=ABCMeta): "optimizer can only optimize Parameters, but one of the params is " + str(type(param)) ) - param._reset(Tensor(param.numpy(), no_cache=True, format=param.format)) + param._reset(Tensor(param, no_cache=True)) for name, default in self._defaults.items(): if default is required and name not in param_group: @@ -119,10 +120,11 @@ class Optimizer(metaclass=ABCMeta): def _add_state(self, param, state_name, initializer=None): if initializer is None: - initializer = np.zeros(param.shape, dtype=np.float32) + with _config._override(auto_format_convert=False): + initializer = np.zeros(param.shape, dtype=np.float32) state_dict = self._state.setdefault(param, {}) assert state_name not in state_dict - state = Tensor(initializer, no_cache=True) + state = Tensor(initializer, no_cache=True, format=param.format) state_dict[state_name] = state @abstractmethod diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 94a358728..caeafb360 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -5,6 +5,7 @@ from typing import Iterable, Union from 
..functional.inplace import _inplace_add_ from ..tensor import Parameter, tensor from .optimizer import Optimizer +from ..core import _config class SGD(Optimizer): diff --git a/imperative/python/test/unit/amp/test_convert_format.py b/imperative/python/test/unit/amp/test_convert_format.py index a2b199be3..bfed1a16f 100644 --- a/imperative/python/test/unit/amp/test_convert_format.py +++ b/imperative/python/test/unit/amp/test_convert_format.py @@ -10,7 +10,7 @@ import pytest import megengine.functional as F import megengine.module as M -from megengine import Parameter, Tensor, amp, tensor +from megengine import Parameter, Tensor, amp, config class MyModule(M.Module): @@ -39,6 +39,22 @@ class MyModule(M.Module): @pytest.mark.parametrize("is_inplace", [False, True]) def test_convert_module(is_inplace): m = MyModule() + expected_shape = { + "i.bn.weight": (1, 1, 1, 4), + "i.bn.bias": (1, 1, 1, 4), + "i.bn.running_mean": (1, 1, 1, 4), + "i.bn.running_var": (1, 1, 1, 4), + "conv.weight": (2, 2, 4, 4, 2), + "conv.bias": (1, 1, 1, 4), + "bn.weight": (1, 1, 1, 4), + "bn.bias": (1, 1, 1, 4), + "bn.running_mean": (1, 1, 1, 4), + "bn.running_var": (1, 1, 1, 4), + "param": (1, 1, 1, 3), + "buff": (1, 1, 1, 3), + } m = amp.convert_module_format(m, is_inplace) for name, param in m.named_tensors(): assert param.format == "nhwc" + with config._override(auto_format_convert=False): + assert param.shape == expected_shape[name], name diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py index 4f4c5eacc..d1f3d4cfb 100644 --- a/imperative/python/test/unit/core/test_formatted_tensor.py +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -3,6 +3,7 @@ import pytest import megengine as mge import megengine.functional as F +import megengine.module as M from megengine import tensor from megengine.autodiff import GradManager from megengine.jit import trace @@ -36,9 +37,9 @@ def _compare_nchw_nhwc(data, func, is_symbolic=None): x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") if is_symbolic is not None: func = trace(func, symbolic=is_symbolic) - # out1 = func(x1) + out1 = func(x1) out2 = func(x2) - # np.testing.assert_almost_equal(out1, out2, decimal=5) + np.testing.assert_almost_equal(out1, out2, decimal=5) @pytest.mark.parametrize("is_symbolic", [None]) @@ -322,30 +323,91 @@ def test_pooling2d(pooling, is_symbolic): _compare_nchw_nhwc(data, func, is_symbolic) -@pytest.mark.parametrize("is_symbolic", [None]) -def test_backward(is_symbolic): - data = np.arange(0, 24).reshape((1, 2, 3, 4)) - x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") - w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") - b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") - gm = GradManager().attach([w, b]) +def _compare_backward(inps, model, is_symbolic=None): + def func(*inps): + return model(*inps) - def func(x, w, b): - return F.conv2d(x, w, b) + if is_symbolic is not None: + func = trace(func, symbolic=is_symbolic) + gm = GradManager().attach(model.parameters()) with gm: - if is_symbolic is not None: - func = trace(func, symbolic=is_symbolic) - x = func(x, w, b) - assert x.format == "nhwc" - # test manually convert to NHWC, usually used in detection head - x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2) - gm.backward(x) - print("finish backward", x.format) - # backward grad has no format - np.testing.assert_equal( - w.grad.numpy(), np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), - ) - np.testing.assert_equal( - b.grad.numpy(), 
np.array([12, 12, 12]).reshape((1, 1, 1, 3)) - ) + rst = func(*inps) + gm.backward(rst) + expected_grads = [param.grad for param in model.parameters()] + + inps = [mge.amp.convert_tensor_format(inp) for inp in inps] + model = mge.amp.convert_module_format(model) + gm = GradManager().attach(model.parameters()) + with gm: + rst = func(*inps) + gm.backward(rst) + actual_grads = [param.grad for param in model.parameters()] + + for expected, actual in zip(expected_grads, actual_grads): + # print(param.grad) + np.testing.assert_equal(expected.numpy(), actual.numpy()) + + +@pytest.mark.parametrize("is_symbolic", [None]) +def test_backward_conv2d_dimshuffle(is_symbolic): + class Net(M.Module): + def __init__(self): + super().__init__() + self.conv = M.Conv2d(2, 3, 1) + + def forward(self, inp): + # test manually convert to NHWC, usually used in detection head + return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2) + + inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) + # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc") + # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc") + # grads = [ + # np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)), + # np.array([12, 12, 12]).reshape((1, 1, 1, 3)), + # ] + _compare_backward([inp], Net(), is_symbolic) + + +@pytest.mark.parametrize("is_symbolic", [None]) +def test_backward_groupconv2d_bn(is_symbolic): + class Net(M.Module): + def __init__(self): + super().__init__() + self.conv = M.Conv2d(2, 2, 1, groups=2) + self.bn = M.BatchNorm2d(2) + + def forward(self, inp): + # test manually convert to NHWC, usually used in detection head + return self.bn(self.conv(inp)) + + inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4))) + _compare_backward([inp], Net(), is_symbolic) + # def func(x, w, b, bn_w, bn_b): + # x = F.conv2d(x, w, b, groups=2) + # x = F.batch_norm( + # x, + # running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + # running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + # weight=bn_w, + # bias=bn_b, + # training=True, + # inplace=True, + # ) + # return x + + # data = np.arange(0, 24).reshape((1, 2, 3, 4)) + # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc") + # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc") + # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc") + # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc") + # grads = [ + # np.array([66, 210]).reshape((2, 1, 1, 1, 1)), + # np.array([12, 12]).reshape((1, 1, 1, 2)), + # np.array([12, 12]).reshape((1, 1, 1, 2)), + # np.array([12, 12]).reshape((1, 1, 1, 2)), + # ] + # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic) diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index b65de6170..54de6cb53 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -1,6 +1,8 @@ #include "megbrain/imperative/transformations/format.h" +#include "megbrain/imperative/transformations/grad.h" #include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/ops/utility.h" namespace mgb { namespace imperative { @@ -17,7 +19,12 @@ TypedValueRef FormatTransformation::to( const std::string& scope) const { std::vector pattern; if (tensor.format() == FT::NHWC && target == FT::NCHW) { - pattern = {0, 3, 1, 2}; + // FIXME(czh): temporary fast path for group conv 5D weight. 
+ if (tensor.value().shape().cast().ndim == 5) { + pattern = {0, 1, 4, 2, 3}; + } else { + pattern = {0, 3, 1, 2}; + } } else if (tensor.format() == FT::NCHW && target == FT::NHWC) { pattern = {0, 2, 3, 1}; } else { @@ -65,12 +72,22 @@ inline ValueRefList FormatTransformation::wrap_outputs( namespace { ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { - mgb_assert(shape.ndim == 4); auto out = ValueShape(shape); - out[3] = shape[2]; - out[2] = shape[1]; - out[1] = shape[3]; - return out; + if (shape.ndim == 4) { + out[1] = shape[3]; + out[2] = shape[1]; + out[3] = shape[2]; + return out; + } else if (shape.ndim == 5) { + out[2] = shape[4]; + out[3] = shape[2]; + out[4] = shape[3]; + return out; + } else { + mgb_throw( + MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).", + shape.ndim); + } } using FormatRule = std::function& inputs, const FormatTransformation& t) { FT format(FT::DEFAULT); for (auto& inp : inputs) { - auto&& inp_ref = inp.as_ref(t.value_type()); - if (inp_ref && inp_ref->format() != FT::DEFAULT) { - mgb_assert(format == FT::DEFAULT || inp_ref->format() == format); - format = inp_ref->format().type(); + auto&& inp_format = inp.cast(t.value_type()).format(); + if (inp_format != FT::DEFAULT) { + mgb_assert(format == FT::DEFAULT || inp_format == format); + format = inp_format.type(); } } return format; @@ -308,13 +325,6 @@ ValueRefList concat_rule( format); } -ValueRefList elemwise_rule( - const Elemwise& op, Span& inputs, const bool& auto_convert, - const FormatTransformation& t) { - FT format = get_inputs_format(inputs, t); - return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format); -} - ValueRefList identity_rule_helper( const OpDef& op, const Span& inputs, const FormatTransformation& t) { // mgb_assert(inputs.size() == 1); @@ -336,24 +346,49 @@ ValueRefList batchnorm_rule( return identity_rule_helper(op, inputs, t); } +ValueRefList checknonfinite_rule( + const CheckNonFinite& op, Span& inputs, const bool& auto_convert, + const FormatTransformation& t) { + auto&& inputs_ = t.unwrap_inputs(inputs); + auto&& outputs_ = imperative::apply(op, inputs_); + return t.wrap_outputs(outputs_); +} + // clang-format off -#define FOREACH_IDENTITY_OP(cb) \ - cb(Copy) \ - cb(FastpathCopy) \ - cb(TypeCvt) \ - cb(Dropout) \ +#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \ + cb(Elemwise) \ + cb(CompiledOp) \ + cb(SubgraphOp) + +#define FOREACH_IDENTITY_OP(cb) \ + cb(Copy) \ + cb(FastpathCopy) \ + cb(TypeCvt) \ + cb(Dropout) \ cb(Identity) -#define FOREACH_FORMAT_OP(cb) \ - cb(AdaptivePooling) \ - cb(WarpAffine) \ +#define FOREACH_FORMAT_OP(cb) \ + cb(AdaptivePooling) \ + cb(WarpAffine) \ cb(Resize) -#define FOREACH_FORMAT_POLICY_OP(cb)\ - cb(Pooling) \ +#define FOREACH_FORMAT_POLICY_OP(cb) \ + cb(Pooling) \ cb(Convolution) // clang-format on +// multi inputs op without params +#define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op) \ + ValueRefList Op##_rule( \ + const Op& _op, Span& inputs, const bool& auto_convert, \ + const FormatTransformation& t) { \ + FT format = get_inputs_format(inputs, t); \ + return t.wrap_outputs( \ + imperative::apply(_op, t.unwrap_inputs(inputs)), format); \ + } +FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE) +#undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE + // identity op #define CREATE_IDENTITY_OP_RULE(Op) \ ValueRefList Op##_rule( \ @@ -409,8 +444,9 @@ struct FormatRuleRegistry { register_format_rule(setsubtensor_rule); register_format_rule(setsubtensor_rule); register_format_rule(concat_rule); - 
register_format_rule(elemwise_rule); register_format_rule(batchnorm_rule); + register_format_rule(checknonfinite_rule); + FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE) FOREACH_IDENTITY_OP(REGISTER_OP_RULE) FOREACH_FORMAT_OP(REGISTER_OP_RULE) FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE) @@ -455,27 +491,73 @@ ValueRefList FormatTransformation::apply_transformation( return imperative::apply(op, unwrap_inputs(inputs)); } } else if (op.is()) { - bool is_formatted_tensor = inputs.item().is(m_value_type); - if (is_formatted_tensor) { - return {FormatValue::make(inputs[0].cast(m_value_type).format())}; + auto&& inp_ref = inputs[0].as_ref(m_value_type); + if (inp_ref) { + return {FormatValue::make(inp_ref->format())}; } else { mgb_log_warn( - "Not FormattedTensorValue input for GetFormat op: %s", - inputs[0].to_string().c_str()); + "Not FormattedTensorValue input for GetFormat op: %s, %s", + op.to_string().c_str(), inputs[0].to_string().c_str()); return {FormatValue::make(FT::DEFAULT)}; } } else if (op.is()) { - bool is_formatted_tensor = inputs.item().is(m_value_type); - if (is_formatted_tensor) { - auto&& format = inputs[0].cast(m_value_type).format(); + auto&& inp_ref = inputs[0].as_ref(m_value_type); + if (inp_ref) { + auto&& format = inp_ref->format(); return wrap_outputs( imperative::apply(op, unwrap_inputs(inputs)), format.type()); } else { mgb_log_warn( - "Not FormattedTensorValue input for IdentityLike op: %s", - inputs[0].to_string().c_str()); + "Not FormattedTensorValue input for IdentityLike op: %s, %s", + op.to_string().c_str(), inputs[0].to_string().c_str()); return imperative::apply(op, inputs); } + } else if (op.is()) { + auto&& inp_ref = inputs[0].as_ref(m_value_type); + if (inp_ref) { + auto format = inp_ref->format(); + GenericFunction callback = + (GenericFunction&)inputs[1].cast(); + GenericFunction new_callback = + [this, callback, format](Span inputs_) -> ValueRefList { + auto wrapped_inputs = SmallVector{ + this->value_type().make(inputs_.item(), format.type())}; + auto ret = callback(wrapped_inputs); + return ret; + }; + auto&& outputs = imperative::apply( + op, inp_ref->value(), FunctionValue::make(new_callback)); + return wrap_outputs(outputs, format.type()); + } else { + mgb_log_warn( + "Not FormattedTensorValue input for AttachGrad op: %s, %s", + op.to_string().c_str(), inputs[0].to_string().c_str()); + return imperative::apply(op, inputs); + } + } else if (auto* set_grad = op.as()) { + size_t nr_inputs = set_grad->nr_inputs(); + size_t nr_outputs = inputs.size() - nr_inputs; + Span inputs_ = {inputs.data(), nr_inputs}; + Span outputs_ = {inputs.data() + nr_inputs, nr_outputs}; + + // run original apply. 
+ // grads needn't to unwrap and wrap, which will be unwrapped in GradTrans + auto&& outputs = imperative::apply(op, unwrap_inputs(inputs)); + + // handle output's formats + auto wrapped_outputs = ValueRefList(nr_outputs); + for (size_t i = 0; i < nr_outputs; ++i) { + if (auto output_ref = outputs_[i].as_ref(m_value_type)) { + wrapped_outputs[i] = + m_value_type.make(outputs[i], output_ref->format().type()); + } else { + mgb_log_warn( + "Not FormattedTensorValue outputs for SetGrad op: %s, %s", + op.to_string().c_str(), inputs_[i].to_string().c_str()); + wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT); + } + } + return wrapped_outputs; } else { return imperative::apply(op, unwrap_inputs(inputs)); } diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h index 0d1b5c593..90427ee2b 100644 --- a/imperative/src/include/megbrain/imperative/transformations/format.h +++ b/imperative/src/include/megbrain/imperative/transformations/format.h @@ -47,7 +47,10 @@ public: const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { - mgb_assert(!value.is(m_value_type)); + //mgb_assert(!value.is(m_value_type)); + if (auto format_val = value.as_ref(m_value_type)) { + return format_val->value(); + } return value; } diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h index 7f4a280e9..ce392a11b 100644 --- a/imperative/src/include/megbrain/imperative/transformations/grad.h +++ b/imperative/src/include/megbrain/imperative/transformations/grad.h @@ -377,6 +377,8 @@ public: SetGrad(GenericFunction grad_fn, size_t nr_inputs) : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} + std::shared_ptr key() const { return m_key; } + GenericFunction grad_fn() const { return m_grad_fn; } size_t nr_inputs() const { return m_nr_inputs; } -- GitLab
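
Usage sketch of the format-aware backward path this patch fixes, mirroring the new
_compare_backward helper added in test_formatted_tensor.py. The module definition,
input shape and tolerance below are illustrative assumptions; convert_tensor_format,
convert_module_format and GradManager are the entry points the patch actually touches.

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    from megengine.autodiff import GradManager

    class SimpleNet(M.Module):  # hypothetical module, modelled on the Net in the new tests
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(2, 3, 1)

        def forward(self, x):
            # detection-head style dimshuffle + reshape, as in test_backward_conv2d_dimshuffle
            return F.transpose(self.conv(x), (0, 2, 3, 1)).reshape(1, 18, 2)

    inp = mge.tensor(np.arange(24, dtype="float32").reshape(1, 2, 3, 4))
    model = SimpleNet()

    # reference backward with default (NCHW) tensors
    gm = GradManager().attach(model.parameters())
    with gm:
        gm.backward(model(inp))
    ref_grads = [p.grad.numpy() for p in model.parameters()]

    # NHWC backward: convert the input and the parameters, then rerun the same graph.
    # With this patch, AttachGrad/SetGrad keep the format metadata on their outputs,
    # so the gradients should agree with the NCHW reference.
    nhwc_inp = mge.amp.convert_tensor_format(inp)
    nhwc_model = mge.amp.convert_module_format(model)
    gm = GradManager().attach(nhwc_model.parameters())
    with gm:
        gm.backward(nhwc_model(nhwc_inp))
    for ref, got in zip(ref_grads, [p.grad.numpy() for p in nhwc_model.parameters()]):
        np.testing.assert_allclose(ref, got, rtol=1e-5)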
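
A second sketch covering the Python-side AMP pieces touched here (autocast enter/exit
bookkeeping, GradScaler._check_gradients returning a host value, optimizer state
following the parameter format). The tiny model, data and hyper-parameters are
placeholders; the training-step pattern follows standard MegEngine autocast +
GradScaler usage rather than anything introduced by this patch.

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine.amp import GradScaler, autocast
    from megengine.autodiff import GradManager

    model = M.Linear(32, 10)
    opt = optim.SGD(model.parameters(), lr=0.01)
    gm = GradManager().attach(model.parameters())
    scaler = GradScaler()

    def train_step(inp, label):
        with gm:
            with autocast(enabled=True):  # saves/restores amp dtypes only when enabled
                pred = model(inp)
                loss = F.nn.cross_entropy(pred, label)
            # GradScaler scales the loss, runs backward and unscales the grads;
            # the overflow check now reads its flag back to the host.
            scaler.backward(gm, loss)
        opt.step().clear_grad()
        return loss

    inp = mge.tensor(np.random.randn(8, 32).astype("float32"))
    label = mge.tensor(np.random.randint(10, size=(8,)).astype("int32"))
    train_step(inp, label)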