From b9e0319fc9591501a93c3480bcc301634ba64563 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 Jul 2023 14:24:04 +0800 Subject: [PATCH] fix(imperative): fix format transformation handle nchw format tensor GitOrigin-RevId: f5838c1f7fbc1a1f4ffd9fc8951ed0cbdd422dc2 --- .../python/test/unit/core/test_formatted_tensor.py | 4 ++++ imperative/src/impl/transformations/format.cpp | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py index 659640f06..6b531b8f7 100644 --- a/imperative/python/test/unit/core/test_formatted_tensor.py +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -20,6 +20,10 @@ def test_basic(): b = tensor(a) assert b.format == "nhwc" + b = tensor(data, format="nchw") + result = F.pad(b, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="reflect") + assert result.format == "default" + # TODO: init from tensor with new format # c = tensor(a, format="nchw") # assert c.format == "nchw" diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index 42fd66752..37c8fc9cb 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -435,13 +435,22 @@ inline FT get_inputs_format(Span& inputs, const FormatTransformation& return format; } +inline bool if_convert_format(const Format src_fmt, const FT& dst_fmt) { + if ((src_fmt == FT::NCHW && dst_fmt == FT::DEFAULT) || + (src_fmt == FT::DEFAULT && dst_fmt == FT::NCHW)) { + return false; + } else { + return true; + } +} + inline ValueRefList unify_inputs_format( const Span& inputs, const FT& dst_fmt, const std::string& scope, const FormatTransformation& t) { ValueRefList unified_inputs(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { auto&& inp = inputs[i].cast(t.value_type()); - if (inp.format() != dst_fmt) { + if (inp.format() != dst_fmt && if_convert_format(inp.format(), dst_fmt)) { unified_inputs[i] = t.to(inp, dst_fmt, scope); } else { unified_inputs[i] = inputs[i]; -- GitLab