diff --git a/imperative/python/test/unit/core/test_formatted_tensor.py b/imperative/python/test/unit/core/test_formatted_tensor.py
index a80310c5179b23b04eafd74d77a122fe1e539986..b300ad8e323abf56e5a92896731c05fbd2307e33 100644
--- a/imperative/python/test/unit/core/test_formatted_tensor.py
+++ b/imperative/python/test/unit/core/test_formatted_tensor.py
@@ -224,7 +224,6 @@ def test_interpolate(mode, is_symbolic):
         assert rst.format == x.format
         return rst.numpy()

-    # NHWC interpolate only suppoted channel is 1 or 3
     data = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
     _compare_nchw_nhwc(data, func, is_symbolic)

@@ -331,6 +330,35 @@ def test_pooling2d(pooling, is_symbolic):
     _compare_nchw_nhwc(data, func, is_symbolic)


+@pytest.mark.skip("not implemented")
+def test_where():
+    def func(x):
+        mask = tensor(
+            np.array([True, False, False, True] * 6, dtype=np.bool).reshape(
+                (1, 2, 3, 4)
+            )
+        )
+        y = tensor(
+            np.array([1, np.inf, np.nan, 4] * 6, dtype=np.float32).reshape((1, 2, 3, 4))
+        )
+        out = F.where(mask, x, y)
+        assert out.format == "default"
+        return out.numpy()
+
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    _compare_nchw_nhwc(data, func)
+
+
+def test_unsupported_op():
+    def func(x):
+        rst = F.nn.pad(x, pad_width=((1, 1),), mode="constant")
+        assert rst.format == "default"
+        return rst.numpy()
+
+    data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    _compare_nchw_nhwc(data, func)
+
+
 def _compare_backward(inps, model, is_symbolic=None):
     def func(*inps):
         return model(*inps)
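
[editor's note, not part of the patch] The two new tests pin down the fallback
behavior this PR introduces: a not-yet-implemented op (F.where) and an op with
no registered format rule (F.nn.pad) should return tensors whose format is
"default" rather than propagating NHWC. Both go through the _compare_nchw_nhwc
helper defined earlier in this file, which the hunks do not show. Below is a
minimal sketch of what that helper presumably does; tensor(..., format="nhwc")
and mge.config._override(auto_format_convert=True) are assumptions about the
surrounding test code, not lines taken from this patch:

    import numpy as np
    import megengine as mge
    from megengine import tensor

    def _compare_nchw_nhwc(data, func, is_symbolic=None):
        # reference run on a plain tensor in default (NCHW) layout
        x1 = tensor(data)
        out1 = func(x1)
        # same values stored as NHWC; with auto format conversion enabled,
        # func should observe it as if it were NCHW (trace/symbolic handling
        # is omitted from this sketch)
        x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
        with mge.config._override(auto_format_convert=True):
            out2 = func(x2)
        np.testing.assert_almost_equal(out1, out2, decimal=5)
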
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index 3ce1bcee9566ef8e982fcbd2f075090938e1ab23..d4914056fd967dc662fe37db0c4c87b8b358f17a 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -18,20 +18,20 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const FormattedTensorValue& tensor, const FT& target,
         const std::string& scope) const {
     std::vector<int32_t> pattern;
-    if (tensor.format() == FT::NHWC && target == FT::NCHW) {
+    Format format = tensor.format();
+    if (format == FT::NHWC && (target == FT::NCHW || target == FT::DEFAULT)) {
         // FIXME(czh): temporary fast path for group conv 5D weight.
         if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
             pattern = {0, 1, 4, 2, 3};
         } else {
             pattern = {0, 3, 1, 2};
         }
-    } else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
+    } else if ((format == FT::NCHW || format == FT::DEFAULT) && target == FT::NHWC) {
         pattern = {0, 2, 3, 1};
     } else {
         mgb_throw(
                 MegBrainError, "Unsupport format conversion from %s to %s",
-                tensor.format().to_string().c_str(),
-                Format(target).to_string().c_str());
+                format.to_string().c_str(), Format(target).to_string().c_str());
     }
     auto output = imperative::apply(
             *Dimshuffle::make(pattern, scope), {tensor.value()})[0];
@@ -84,7 +84,7 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
         return out;
     } else {
         mgb_throw(
-                MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).",
+                MegBrainError, "Unsupported shape ndim %lu in GetAttr(Shape).",
                 shape.ndim);
     }
 }
@@ -189,7 +189,7 @@ ValueRefList reshape_rule(
             return t.wrap_outputs(outputs, FT::NHWC);
         } else {
             // will not maintain src's format
-            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+            auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
             auto outputs = imperative::apply(op, {nchw_src});
             return t.wrap_outputs(outputs);
         }
@@ -204,7 +204,7 @@ ValueRefList reshape_rule(
             return t.wrap_outputs(outputs, FT::NHWC);
         } else {
             // will not maintain src's format
-            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+            auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
             auto outputs = imperative::apply(
                     op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
             return t.wrap_outputs(outputs);
         }
@@ -229,7 +229,7 @@ ValueRefList broadcast_rule(
             return t.wrap_outputs(outputs, FT::NHWC);
         } else {
             // will not maintain src's format
-            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+            auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
             auto outputs = imperative::apply(op, {nchw_src});
             return t.wrap_outputs(outputs);
         }
@@ -244,7 +244,7 @@ ValueRefList broadcast_rule(
             return t.wrap_outputs(outputs, FT::NHWC);
         } else {
             // will not maintain src's format
-            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
+            auto nchw_src = t.to(src, FT::DEFAULT, op.scope())->value();
             auto outputs = imperative::apply(
                     op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
             return t.wrap_outputs(outputs);
         }
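
[editor's note, not part of the patch] The recurring change in the
reshape_rule and broadcast_rule hunks above is that the "will not maintain
src's format" path now converts the NHWC source to FT::DEFAULT instead of
FT::NCHW, which the relaxed to() at the top of this file accepts as a
transpose target. The user-visible effect is that an op which cannot keep the
NHWC layout hands back a format-less ("default") tensor. A hedged Python
illustration; the tensor(..., format=...) constructor, the .format property,
and mge.config._override(auto_format_convert=True) are assumed from
MegEngine's formatted-tensor API, and the reshape target is only an example:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.ones((1, 4, 4, 2), dtype="float32"), format="nhwc")
    with mge.config._override(auto_format_convert=True):
        # flattening mixes the H, W and C axes, so NHWC layout cannot be
        # preserved; the rule falls back to DEFAULT and the format tag is lost
        y = F.reshape(x, (1, 32))
    print(y.format)  # expected: "default"
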
@@ -323,7 +323,7 @@ ValueRefList setsubtensor_rule(
     auto nhwc_inputs = ValueRefList(inputs.size());
     if (format == FT::DEFAULT || format == FT::NCHW) {
         // value for setsubtensor should transpose to match shape.
-        auto nhwc_value = t.to(*(t.as(value, FT::NCHW)), FT::NHWC);
+        auto nhwc_value = t.to(value, FT::NHWC);
         // make new inputs for setsubtensor
         nhwc_inputs[0] = src.value();
         nhwc_inputs[1] = nhwc_value->value();
@@ -355,14 +355,15 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation&
     return format;
 }

-inline ValueRefList unify_nhwc_inputs(
-        Span<ValueRef>& inputs, std::string scope, const FormatTransformation& t) {
+inline ValueRefList unify_inputs_format(
+        const Span<ValueRef>& inputs, const FT& dst_fmt, const std::string& scope,
+        const FormatTransformation& t) {
     ValueRefList unified_inputs(inputs.size());
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto&& inp = inputs[i].cast(t.value_type());
-        if (inp.format() != FT::NHWC &&
+        if (inp.format() != dst_fmt &&
             inp.value().shape().cast<ShapeValue>().ndim == 4) {
-            unified_inputs[i] = t.to(*t.as(inp, FT::NCHW), FT::NHWC, scope);
+            unified_inputs[i] = t.to(inp, dst_fmt, scope);
         } else {
             unified_inputs[i] = inputs[i];
         }
@@ -375,7 +376,7 @@ ValueRefList elemwise_rule(
         const FormatTransformation& t) {
     FT format = get_inputs_format(inputs, t);
     if (format == FT::NHWC && auto_convert) {
-        auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
+        auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
         return t.wrap_outputs(
                 imperative::apply(op, t.unwrap_inputs(unified_inputs)), format);
     }
@@ -389,7 +390,7 @@ ValueRefList concat_rule(
     if (!(format == FT::NHWC && auto_convert)) {
         return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), format);
     }
-    auto unified_inputs = unify_nhwc_inputs(inputs, op.scope(), t);
+    auto unified_inputs = unify_inputs_format(inputs, FT::NHWC, op.scope(), t);
     // TODO: handle 5D NHWC Tensor from group conv
     auto axis = op.axis;
     if (axis == 2 || axis == 3) {
@@ -460,6 +461,12 @@ ValueRefList adaptive_pooling_rule(
 #define FOREACH_FORMAT_POLICY_OP(cb) \
     cb(Pooling)                      \
     cb(Convolution)
+
+#define FOREACH_BYPASS_OP(cb) \
+    cb(ParamPackSplit)        \
+    cb(ParamPackConcat)       \
+    cb(CollectiveComm)        \
+    cb(CheckNonFinite)
 // clang-format on

 // multi inputs op without params
@@ -517,6 +524,15 @@ FOREACH_FORMAT_OP(CREATE_FORMAT_OP_RULE)
     }
 FOREACH_FORMAT_POLICY_OP(CREATE_FORMAT_POLICY_OP_RULE)

+#define CREATE_BYPASS_OP_RULE(Op)                                               \
+    ValueRefList Op##_rule(                                                     \
+            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert,    \
+            const FormatTransformation& t) {                                    \
+        return t.wrap_outputs(imperative::apply(_op, t.unwrap_inputs(inputs))); \
+    }
+FOREACH_BYPASS_OP(CREATE_BYPASS_OP_RULE)
+#undef CREATE_BYPASS_OP_RULE
+
 #undef CREATE_FORMAT_OP_RULE
 #define REGISTER_OP_RULE(op) register_format_rule(op##_rule);
 struct FormatRuleRegistry {
@@ -536,6 +552,7 @@ struct FormatRuleRegistry {
         FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
+        FOREACH_BYPASS_OP(REGISTER_OP_RULE)
     }
 } _;
 #undef REGISTER_OP_RULE
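
[editor's note, not part of the patch] Two things happen in the hunks above.
First, unify_nhwc_inputs is generalized into unify_inputs_format, so the same
helper can pull mixed inputs up to NHWC (elemwise_rule, concat_rule) or push
them back down to DEFAULT (the fallback path in apply_transformation below).
Second, FOREACH_BYPASS_OP stamps out rules for ParamPackSplit, ParamPackConcat,
CollectiveComm and CheckNonFinite that simply unwrap and rewrap, so these ops
never trigger a format conversion. A hedged sketch of the elemwise unification
as seen from Python; the format kwarg, the .format property, and
mge.config._override(auto_format_convert=True) are assumptions, not part of
this patch:

    import numpy as np
    import megengine as mge
    from megengine import tensor

    a = tensor(np.ones((1, 2, 2, 3), dtype="float32"), format="nhwc")
    b = tensor(np.ones((1, 3, 2, 2), dtype="float32"))  # default (NCHW) layout
    with mge.config._override(auto_format_convert=True):
        # elemwise_rule sees one NHWC input, so unify_inputs_format should
        # transpose the default-format operand to NHWC before applying the add
        out = a + b
    print(out.format)  # expected: "nhwc"
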
@@ -549,10 +566,13 @@ ValueRefList FormatTransformation::apply_transformation(
         if (iter != format_rules.end()) {
             return iter->second(apply_op->op(), inputs, m_auto_convert, *this);
         } else {
-            return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)));
+            auto unified_inputs = unify_inputs_format(
+                    inputs, FT::DEFAULT, apply_op->op().scope(), *this);
+            return wrap_outputs(imperative::apply(op, unwrap_inputs(unified_inputs)));
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
+        // TODO: add dimshuffle for nhwc format
         return {wrap_output(imperative::apply(op, inputs)[0], format)};
     } else if (auto* get_attr = op.as<GetAttr>()) {
         auto&& input = inputs.item();
@@ -570,7 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
                 return {ShapeValue::make(shape)};
             }
             case GetAttr::Value: {
-                auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
+                auto nchw_src = unwrap_input(to(src, FT::DEFAULT, ""));
                 return imperative::apply(op, {nchw_src});
             }
             default:
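
[editor's note, not part of the patch] The GetAttr hunk completes the picture:
reading the value of an NHWC tensor now converts it to FT::DEFAULT (an
NHWC-to-NCHW transpose) first, mirroring GetAttr::Shape, which reports the
NCHW-converted shape. Under my reading of these hunks, user code should
therefore observe NCHW-ordered shape and data regardless of the internal
layout. A hedged sketch; the format kwarg and
mge.config._override(auto_format_convert=True) are assumptions:

    import numpy as np
    import megengine as mge
    from megengine import tensor

    data = np.arange(24, dtype="float32").reshape((1, 2, 3, 4))  # NCHW order
    x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")  # stored as NHWC
    with mge.config._override(auto_format_convert=True):
        print(x.shape)   # expected NCHW view: (1, 2, 3, 4)
        out = x.numpy()  # GetAttr::Value path: transposed back before reading
    np.testing.assert_array_equal(out, data)
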