# Patch summary (free text before the first "diff --git" header; patch tools skip it).
#
# test_formatted_tensor.py
#   - test_elemwise: multiplies x by a tensor created with F.ones((1, 2, 3, 4))
#     instead of the scalar 2, and asserts the result keeps x.format.
#   - test_concat: concatenates x / 2 with a tensor created with F.ones((1, 2, 3, 4))
#     instead of x * 2.
#
# format.cpp
#   - Adds unify_nhwc_inputs(): returns a copy of the input list in which every
#     input whose format is not NHWC and whose shape has ndim == 4 is replaced by
#     t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope); all other inputs are passed
#     through unchanged.
#   - Adds elemwise_rule(): when the inputs' format is NHWC and auto_convert is
#     set, applies the op on the unified inputs; otherwise applies it on the
#     inputs as given. Outputs are wrapped with the inputs' format in both cases.
#   - Registers elemwise_rule and removes Elemwise from
#     FOREACH_MULTI_INPS_NO_PARAM_OP.
#   - concat_rule: computes unified inputs after the early-return check and uses
#     them in the Concat::make(...) call shown in the hunk below.
#   - Removes a commented-out mgb_log_warn line from apply_transformation.
#
# NOTE(review): template argument lists appear to be missing throughout the C++
# hunks (e.g. "Span& inputs", ".cast()", "op.as()", and two identical
# "register_format_rule(setsubtensor_rule);" context lines) - presumably lost
# when this patch was extracted. They are reproduced exactly as received;
# confirm against the upstream file before applying.
# NOTE(review): line breaks and indentation were reconstructed from the hunk
# line counts and common formatting conventions - verify context-line whitespace
# against the target tree (or apply with whitespace-insensitive matching).
diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py
index 9366935f866125643eced6ee5a68d77197cad2b1..a80310c5179b23b04eafd74d77a122fe1e539986 100644
--- a/imperative/python/test/unit/core/test_formatted_tensor.py
+++ b/imperative/python/test/unit/core/test_formatted_tensor.py
@@ -193,7 +193,10 @@ def test_typecvt(is_symbolic):
 @pytest.mark.parametrize("is_symbolic", [None])
 def test_elemwise(is_symbolic):
     def elemwise(x):
-        return (x * 2 + x / 2).numpy()
+        tmp = F.ones((1, 2, 3, 4))
+        oup = x * tmp + x / 2
+        assert oup.format == x.format
+        return oup.numpy()
 
     data = np.arange(0, 24).reshape((1, 2, 3, 4))
     _compare_nchw_nhwc(data, elemwise, is_symbolic)
@@ -202,7 +205,8 @@ def test_elemwise(is_symbolic):
 @pytest.mark.parametrize("is_symbolic", [None])
 def test_concat(is_symbolic):
     def func(x):
-        rst = F.concat([x / 2, x * 2], axis=1)
+        tmp = F.ones((1, 2, 3, 4))
+        rst = F.concat([x / 2, tmp], axis=1)
         assert rst.format == x.format
         return rst.numpy()
 
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index 4fc94ad857419798930babba01ba7c4bb1e19ca6..3ce1bcee9566ef8e982fcbd2f075090938e1ab23 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -355,6 +355,33 @@ inline FT get_inputs_format(Span& inputs, const FormatTransformation&
     return format;
 }
 
+inline ValueRefList unify_nhwc_inputs(
+        Span& inputs, std::string scope, const FormatTransformation& t) {
+    ValueRefList unified_inputs(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        auto&& inp = inputs[i].cast(t.value_type());
+        if (inp.format() != FT::NHWC &&
+            inp.value().shape().cast().ndim == 4) {
+            unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope);
+        } else {
+            unified_inputs[i] = inputs[i];
+        }
+    }
+    return unified_inputs;
+}
+
+ValueRefList elemwise_rule(
+        const Elemwise& op, Span& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
+    FT format = get_inputs_format(inputs, t);
+    if (format == FT::NHWC && auto_convert) {
+        auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
+        return t.wrap_outputs(
+                imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
+    }
+    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
+}
+
 ValueRefList concat_rule(
         const Concat& op, Span& inputs, const bool& auto_convert,
         const FormatTransformation& t) {
@@ -362,6 +389,7 @@ ValueRefList concat_rule(
     if (!(format == FT::NHWC && auto_convert)) {
         return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
+    auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
     if (axis == 2 || axis == 3) {
@@ -372,7 +400,7 @@ ValueRefList concat_rule(
     return t.wrap_outputs(
             imperative::apply(
                     *Concat::make(axis, op.comp_node, op.scope()),
-                    t.unwrap_inputs(inputs)),
+                    t.unwrap_inputs(unified_inputs)),
             format);
 }
 
@@ -415,7 +443,6 @@ ValueRefList adaptive_pooling_rule(
 
 // clang-format off
 #define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
-    cb(Elemwise)                           \
     cb(CompiledOp)                         \
     cb(SubgraphOp)
 
@@ -501,6 +528,7 @@ struct FormatRuleRegistry {
         register_format_rule(subtensor_rule);
         register_format_rule(setsubtensor_rule);
         register_format_rule(setsubtensor_rule);
+        register_format_rule(elemwise_rule);
         register_format_rule(concat_rule);
         register_format_rule(batchnorm_rule);
         register_format_rule(adaptive_pooling_rule);
@@ -515,7 +543,6 @@ struct FormatRuleRegistry {
 
 ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span inputs) {
-    // mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
     if (auto* apply_op = op.as()) {
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());