From c9e56f4987e51c2fafd9c63c73ce0962c5d038bf Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 26 May 2022 15:36:27 +0800 Subject: [PATCH] feat(imperative/amp): add dimshuffle before creating nhwc tensor GitOrigin-RevId: 4461f9a0d3d6428282fcd858c677263d9a08fa9f --- .../python/megengine/amp/convert_format.py | 2 +- .../test/unit/core/test_formatted_tensor.py | 18 +++++++++--------- imperative/src/impl/transformations/format.cpp | 15 ++++++++++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/imperative/python/megengine/amp/convert_format.py b/imperative/python/megengine/amp/convert_format.py index 86d04aceb..657ee1f99 100644 --- a/imperative/python/megengine/amp/convert_format.py +++ b/imperative/python/megengine/amp/convert_format.py @@ -33,7 +33,7 @@ def convert_tensor_format(x: Tensor, inplace: bool = True): if x.format != "nhwc": if inplace: # hostvalue should still be valid, so no d2h cost. - data = x.numpy().transpose(*pattern) + data = x.numpy() # reset will destroy existed backward grad x[...] = Tensor(data, format="nhwc") else: diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py index b300ad8e3..8c00c2334 100644 --- a/imperative/python/test/unit/core/test_formatted_tensor.py +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -38,7 +38,7 @@ def test_basic(): def _compare_nchw_nhwc(data, func, is_symbolic=None): x1 = tensor(data) - x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc") + x2 = tensor(data, format="nhwc") if is_symbolic is not None: func = trace(func, symbolic=is_symbolic) out1 = func(x1) @@ -247,8 +247,8 @@ def test_conv2d(is_symbolic): if x.format == "nhwc": x = F.conv2d( x, - weight=mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc"), - bias=mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc"), + weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"), + bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"), ) assert x.format == "nhwc" return x.numpy() @@ -265,8 +265,8 @@ def test_group_conv2d(is_symbolic): if x.format == "nhwc": x = F.conv2d( x, - weight=mge.tensor(np.ones((2, 2, 1, 1, 2)), format="nhwc"), - bias=mge.tensor(np.ones((1, 1, 1, 4)), format="nhwc"), + weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"), + bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"), groups=2, ) assert x.format == "nhwc" @@ -286,10 +286,10 @@ def test_bn(is_symbolic): if x.format == "nhwc": oups = F.batch_norm( x.astype("float32"), - running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), - running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), - weight=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), - bias=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"), + running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), + running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), + weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), + bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"), training=True, inplace=False, ) diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index d4914056f..179779680 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -27,7 +27,11 @@ TypedValueRef FormatTransformation::to( pattern = {0, 3, 1, 2}; } } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) { - pattern = {0, 2, 3, 1}; + if (tensor.value().shape().cast().ndim == 5) { + pattern = {0, 1, 3, 4, 2}; + } else { + pattern = {0, 2, 3, 1}; + } } else { mgb_throw( MegBrainError, "Unsupport format conversion from %s to %s", @@ -572,8 +576,13 @@ ValueRefList FormatTransformation::apply_transformation( } } else if (auto* create_tensor = op.as()) { auto format = create_tensor->format(); - // TODO: add dimshuffle for nhwc format - return {wrap_output(imperative::apply(op, inputs)[0], format)}; + if (format == FT::NHWC) { + auto output = wrap_output(imperative::apply(op, inputs)[0]); + output = to(output.cast(m_value_type), FT::NHWC, ""); + return {output}; + } else { + return {wrap_output(imperative::apply(op, inputs)[0], format)}; + } } else if (auto* get_attr = op.as()) { auto&& input = inputs.item(); if (!input.is(m_value_type)) { -- GitLab