diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 8f492300956b8a5fb4052e5dd7a2f4893ae91b03..394494ae04a18ba577f75bb18304c8c94e8c9a69 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -260,7 +260,6 @@ class GradManager: push_scope("backward") set_option("record_computing_path", 0) _origin_auto_format = get_auto_format_convert() - set_auto_format_convert(False) from ..functional import ones_like global backwarding_grad_manager @@ -304,7 +303,6 @@ class GradManager: self.release() backwarding_grad_manager = cache set_option("record_computing_path", 1) - set_auto_format_convert(_origin_auto_format) pop_scope("backward") def record(self): diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 7ad187b004b6c66434779d40ece4d901fb474ca6..66f9ad65c7070d59803e7744253dc10fd289195f 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -274,7 +274,9 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: return x # set x's format to use FormatTransformation rule for Broadcast. - return broadcast_to(x, inp.shape) + rst = broadcast_to(x, inp.shape) + rst.format = inp.format + return rst def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index 9edb442695ca21bccbca3156379a21e5defeadc3..10d5df5676c60148990c74a1d7ed27f6891807d7 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -26,7 +26,7 @@ public: Eval, }; - std::array>, 8> segments; + std::array>, 9> segments; private: template diff --git a/imperative/python/test/unit/amp/test_convert_format.py b/imperative/python/test/unit/amp/test_convert_format.py index 2724b829bd81857b8847bacd3d67cf9a15072457..d9749a14893904ae2978c4fbf253a4abdaab8bb3 100644 --- a/imperative/python/test/unit/amp/test_convert_format.py +++ b/imperative/python/test/unit/amp/test_convert_format.py @@ -12,6 +12,7 @@ import megengine.functional as F import megengine.module as M from megengine import Parameter, Tensor, amp from megengine.core._config import set_auto_format_convert +from megengine.core._trace_option import use_symbolic_shape class MyModule(M.Module): @@ -41,22 +42,25 @@ class MyModule(M.Module): def test_convert_module(is_inplace): m = MyModule() expected_shape = { - "i.bn.weight": (1, 1, 1, 4), - "i.bn.bias": (1, 1, 1, 4), - "i.bn.running_mean": (1, 1, 1, 4), - "i.bn.running_var": (1, 1, 1, 4), - "conv.weight": (2, 2, 4, 4, 2), - "conv.bias": (1, 1, 1, 4), - "bn.weight": (1, 1, 1, 4), - "bn.bias": (1, 1, 1, 4), - "bn.running_mean": (1, 1, 1, 4), - "bn.running_var": (1, 1, 1, 4), - "param": (1, 1, 1, 3), - "buff": (1, 1, 1, 3), + "i.bn.weight": (1, 4, 1, 1), + "i.bn.bias": (1, 4, 1, 1), + "i.bn.running_mean": (1, 4, 1, 1), + "i.bn.running_var": (1, 4, 1, 1), + "conv.weight": (2, 2, 2, 4, 4), + "conv.bias": (1, 4, 1, 1), + "bn.weight": (1, 4, 1, 1), + "bn.bias": (1, 4, 1, 1), + "bn.running_mean": (1, 4, 1, 1), + "bn.running_var": (1, 4, 1, 1), + "param": (1, 3, 1, 1), + "buff": (1, 3, 1, 1), } m = amp.convert_module_format(m, is_inplace) for name, param in m.named_tensors(): assert param.format == "nhwc" - set_auto_format_convert(False) - assert param.shape == expected_shape[name], name - set_auto_format_convert(True) + if use_symbolic_shape(): + np.testing.assert_array_equal( + param.shape.numpy(), expected_shape[name], name + ) + else: + assert param.shape == expected_shape[name], name diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py index 8c00c23347bae4b27e8707fb12122e750789b47a..659640f0629e5feb384cf1437d2e608f9212aa65 100644 --- a/imperative/python/test/unit/core/test_formatted_tensor.py +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -6,6 +6,7 @@ import megengine.functional as F import megengine.module as M from megengine import tensor from megengine.autodiff import GradManager +from megengine.core._trace_option import use_symbolic_shape from megengine.jit import trace @@ -121,7 +122,10 @@ def test_repeat(is_symbolic): @pytest.mark.parametrize("is_symbolic", [None]) def test_getshape(is_symbolic): def func(x): - return x.shape + if use_symbolic_shape(): + return x.shape.numpy() + else: + return x.shape data = np.arange(0, 24).reshape((1, 2, 3, 4)) _compare_nchw_nhwc(data, func, is_symbolic) diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index a80d7fc5278e81ad471123422c7f85f2c223c68e..83b7549ebef094705631b3bc0510adc5e686b2c8 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -1,5 +1,6 @@ #include "megbrain/imperative/transformations/format.h" #include "megbrain/imperative/transformations/grad.h" +#include "megbrain/imperative/transformations/symbol.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/utility.h" @@ -75,6 +76,17 @@ inline ValueRefList FormatTransformation::wrap_outputs( } return wrapped_outputs; } + +inline bool FormatTransformation::check_all_format_value( + const Span& inputs) const { + for (size_t i = 0; i < inputs.size(); ++i) { + if (!inputs[i].as_ref(m_value_type)) { + return false; + } + } + return true; +} + namespace { ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) { @@ -369,7 +381,8 @@ inline ValueRefList unify_inputs_format( for (size_t i = 0; i < inputs.size(); ++i) { auto&& inp = inputs[i].cast(t.value_type()); if (inp.format() != dst_fmt && - inp.value().shape().cast().ndim == 4) { + (inp.value().shape().cast().ndim == 4 || + inp.value().shape().cast().ndim == 5)) { unified_inputs[i] = t.to(inp, dst_fmt, scope); } else { unified_inputs[i] = inputs[i]; @@ -568,6 +581,10 @@ struct FormatRuleRegistry { ValueRefList FormatTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* apply_op = op.as()) { + // bypass SymbolValue + if (!check_all_format_value(inputs)) { + return imperative::apply(op, unwrap_inputs(inputs)); + } // all inputs should be FormattedTensorValue auto iter = format_rules.find(apply_op->op().dyn_typeinfo()); if (iter != format_rules.end()) { @@ -628,9 +645,6 @@ ValueRefList FormatTransformation::apply_transformation( auto&& format = inp_ref->format(); return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format); } else { - mgb_log_warn( - "Not FormattedTensorValue input for IdentityLike op: %s, %s", - op.to_string().c_str(), inputs[0].to_string().c_str()); return imperative::apply(op, inputs); } } else if (op.is()) { diff --git a/imperative/src/include/megbrain/imperative/transformations/format.h b/imperative/src/include/megbrain/imperative/transformations/format.h index fcd572a1ce98b27e29b51549687c1cf4db7f2e1b..8ea235d6480e2e709c0a3acd457777f2a0de6595 100644 --- a/imperative/src/include/megbrain/imperative/transformations/format.h +++ b/imperative/src/include/megbrain/imperative/transformations/format.h @@ -70,6 +70,7 @@ public: const ValueRef& output, Format format = Format::Type::DEFAULT) const; inline ValueRefList wrap_outputs( const ValueRefList& outputs, Format format = Format::Type::DEFAULT) const; + inline bool check_all_format_value(const Span& inputs) const; TypedValueRef as( const FormattedTensorValue&, const Format::Type& target) const;