From d313f92610df10630d4d279f93577b157bd7489c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sat, 28 May 2022 16:56:43 +0800
Subject: [PATCH] fix(imperative/amp): fix format transformation for symbol
 trans

GitOrigin-RevId: 96cc237c67e25c8cb1567eb08325db65adc1c57d
---
 .../python/megengine/autodiff/grad_manager.py |  2 --
 .../python/megengine/functional/tensor.py     |  4 ++-
 imperative/python/src/transformation.h        |  2 +-
 .../test/unit/amp/test_convert_format.py      | 34 +++++++++++--------
 .../test/unit/core/test_formatted_tensor.py   |  6 +++-
 .../src/impl/transformations/format.cpp       | 22 +++++++++---
 .../imperative/transformations/format.h       |  1 +
 7 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 8f4923009..394494ae0 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -260,7 +260,6 @@ class GradManager:
         push_scope("backward")
         set_option("record_computing_path", 0)
         _origin_auto_format = get_auto_format_convert()
-        set_auto_format_convert(False)
         from ..functional import ones_like
 
         global backwarding_grad_manager
@@ -304,7 +303,6 @@ class GradManager:
         self.release()
         backwarding_grad_manager = cache
         set_option("record_computing_path", 1)
-        set_auto_format_convert(_origin_auto_format)
         pop_scope("backward")
 
     def record(self):
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 7ad187b00..66f9ad65c 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
         return x
     # set x's format to use FormatTransformation rule for Broadcast.
-    return broadcast_to(x, inp.shape)
+    rst = broadcast_to(x, inp.shape)
+    rst.format = inp.format
+    return rst
 
 
 def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h
index 9edb44269..10d5df567 100644
--- a/imperative/python/src/transformation.h
+++ b/imperative/python/src/transformation.h
@@ -26,7 +26,7 @@ public:
         Eval,
     };
 
-    std::array<std::vector<std::shared_ptr<Transformation>>, 8> segments;
+    std::array<std::vector<std::shared_ptr<Transformation>>, 9> segments;
 
 private:
     template <Segment segment>
diff --git a/imperative/python/test/unit/amp/test_convert_format.py b/imperative/python/test/unit/amp/test_convert_format.py
index 2724b829b..d9749a148 100644
--- a/imperative/python/test/unit/amp/test_convert_format.py
+++ b/imperative/python/test/unit/amp/test_convert_format.py
@@ -12,6 +12,7 @@ import megengine.functional as F
 import megengine.module as M
 from megengine import Parameter, Tensor, amp
 from megengine.core._config import set_auto_format_convert
+from megengine.core._trace_option import use_symbolic_shape
 
 
 class MyModule(M.Module):
@@ -41,22 +42,25 @@
 def test_convert_module(is_inplace):
     m = MyModule()
     expected_shape = {
-        "i.bn.weight": (1, 1, 1, 4),
-        "i.bn.bias": (1, 1, 1, 4),
-        "i.bn.running_mean": (1, 1, 1, 4),
-        "i.bn.running_var": (1, 1, 1, 4),
-        "conv.weight": (2, 2, 4, 4, 2),
-        "conv.bias": (1, 1, 1, 4),
-        "bn.weight": (1, 1, 1, 4),
-        "bn.bias": (1, 1, 1, 4),
-        "bn.running_mean": (1, 1, 1, 4),
-        "bn.running_var": (1, 1, 1, 4),
-        "param": (1, 1, 1, 3),
-        "buff": (1, 1, 1, 3),
+        "i.bn.weight": (1, 4, 1, 1),
+        "i.bn.bias": (1, 4, 1, 1),
+        "i.bn.running_mean": (1, 4, 1, 1),
+        "i.bn.running_var": (1, 4, 1, 1),
+        "conv.weight": (2, 2, 2, 4, 4),
+        "conv.bias": (1, 4, 1, 1),
+        "bn.weight": (1, 4, 1, 1),
+        "bn.bias": (1, 4, 1, 1),
+        "bn.running_mean": (1, 4, 1, 1),
+        "bn.running_var": (1, 4, 1, 1),
+        "param": (1, 3, 1, 1),
+        "buff": (1, 3, 1, 1),
     }
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
-        set_auto_format_convert(False)
-        assert param.shape == expected_shape[name], name
-        set_auto_format_convert(True)
+        if use_symbolic_shape():
+            np.testing.assert_array_equal(
+                param.shape.numpy(), expected_shape[name], name
+            )
+        else:
+            assert param.shape == expected_shape[name], name
diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py
index 8c00c2334..659640f06 100644
--- a/imperative/python/test/unit/core/test_formatted_tensor.py
+++ b/imperative/python/test/unit/core/test_formatted_tensor.py
@@ -6,6 +6,7 @@ import megengine.functional as F
 import megengine.module as M
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.core._trace_option import use_symbolic_shape
 from megengine.jit import trace
 
 
@@ -121,7 +122,10 @@ def test_repeat(is_symbolic):
 @pytest.mark.parametrize("is_symbolic", [None])
 def test_getshape(is_symbolic):
     def func(x):
-        return x.shape
+        if use_symbolic_shape():
+            return x.shape.numpy()
+        else:
+            return x.shape
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
     _compare_nchw_nhwc(data, func, is_symbolic)
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index a80d7fc52..83b7549eb 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -1,5 +1,6 @@
 #include "megbrain/imperative/transformations/format.h"
"megbrain/imperative/transformations/grad.h" +#include "megbrain/imperative/transformations/symbol.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/utility.h" @@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs( } return wrapped_outputs; } + +inline bool FormatTransformation::check_all_format_value( + const Span& inputs) const { + for (size_t i = 0; i < inputs.size(); ++i) { + if (!inputs[i].as_ref(m_value_type)) { + return false; + } + } + return true; +} + namespace { ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { @@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format( for (size_t i = 0; i < inputs.size(); ++i) { auto&& inp = inputs[i].cast(t.value_type()); if (inp.format() != dst_fmt && - inp.value().shape().cast().ndim == 4) { + (inp.value().shape().cast().ndim == 4 || + inp.value().shape().cast().ndim == 5)) { unified_inputs[i] = t.to(inp, dst_fmt, scope); } else { unified_inputs[i] = inputs[i]; @@ -568,6 +581,10 @@ struct FormatRuleRegistry { ValueRefList FormatTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* apply_op = op.as()) { + // bypass SymbolValue + if (!check_all_format_value(inputs)) { + return imperative::apply(op, unwrap_inputs(inputs)); + } // all inputs should be FormattedTensorValue auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); if (iter != format_rules.end()) { @@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation( auto&& format = inp_ref->format(); return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); } else { - mgb_log_warn( - "Not FormattedTensorValue input for IdentityLike op: %s, %s", - op.to_string().c_str(), inputs[0].to_string().c_str()); return imperative::apply(op, inputs); } } else if (op.is()) { diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h index fcd572a1c..8ea235d64 100644 --- a/imperative/src/include/megbrain/imperative/transformations/format.h +++ b/imperative/src/include/megbrain/imperative/transformations/format.h @@ -70,6 +70,7 @@ public: const ValueRef& output, Format format = Format::Type::DEFAULT) const; inline ValueRefList wrap_outputs( const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const; + inline bool check_all_format_value(const Span& inputs) const; TypedValueRef as( const FormattedTensorValue&, const Format::Type& target) const; -- GitLab