diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py index 659640f0629e5feb384cf1437d2e608f9212aa65..6b531b8f768390397dec508044e782cac82a8454 100644 --- a/imperative/python/test/unit/core/test_formatted_tensor.py +++ b/imperative/python/test/unit/core/test_formatted_tensor.py @@ -20,6 +20,10 @@ def test_basic(): b = tensor(a) assert b.format == "nhwc" + b = tensor(data, format="nchw") + result = F.pad(b, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="reflect") + assert result.format == "default" + # TODO: init from tensor with new format # c = tensor(a, format="nchw") # assert c.format == "nchw" diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp index 42fd66752bacdded725188273b1bbee126a966aa..37c8fc9cb22a5883ad7da2990fccc92cd77b4dfd 100644 --- a/imperative/src/impl/transformations/format.cpp +++ b/imperative/src/impl/transformations/format.cpp @@ -435,13 +435,22 @@ inline FT get_inputs_format(Span& inputs, const FormatTransformation& return format; } +inline bool if_convert_format(const Format src_fmt, const FT& dst_fmt) { + if ((src_fmt == FT::NCHW && dst_fmt == FT::DEFAULT) || + (src_fmt == FT::DEFAULT && dst_fmt == FT::NCHW)) { + return false; + } else { + return true; + } +} + inline ValueRefList unify_inputs_format( const Span& inputs, const FT& dst_fmt, const std::string& scope, const FormatTransformation& t) { ValueRefList unified_inputs(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { auto&& inp = inputs[i].cast(t.value_type()); - if (inp.format() != dst_fmt) { + if (inp.format() != dst_fmt && if_convert_format(inp.format(), dst_fmt)) { unified_inputs[i] = t.to(inp, dst_fmt, scope); } else { unified_inputs[i] = inputs[i];