diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py
index 481c63b595b0952c45a9c128901339606f722cc5..f7aff5bc4d6a8c4833086ae3ee5926fe39507f41 100755
--- a/dnn/scripts/opr_param_defs.py
+++ b/dnn/scripts/opr_param_defs.py
@@ -1038,7 +1038,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
         'NCHW_NCHW64 = 27',
         'NCHW64_NCHW = 28',
         'NCHW_NHWC = 29',
-        'NHWC_NCHW = 30',
+        'NHWC_NCHW = 30',
+        'NHWCD4I_NHWC = 31',
     )
 )
 
diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp
index 56a91c39f9ec241695469173da3401d1e2f952b4..478262f2f4b97093eb49393af3259c2ffb1fb5ef 100644
--- a/dnn/src/common/relayout_format.cpp
+++ b/dnn/src/common/relayout_format.cpp
@@ -114,6 +114,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds
             dst[3] = src[2];
             dst[4] = 4;
             break;
+        case Param::Mode::NHWCD4I_NHWC:
         case Param::Mode::NHWCD4_NHWC:
             megdnn_assert(src.ndim == 5);
             dst.ndim = 4;
@@ -331,6 +332,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
             CHECK_SRC(DefaultTensorFormat::make());
             dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
             break;
+        case Param::Mode::NHWCD4I_NHWC:
         case Param::Mode::NHWCD4I_NCHW:
             CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
             dst = DefaultTensorFormat::make();
@@ -594,6 +596,7 @@ void RelayoutFormat::deduce_exec_layout(
                             .dimshuffle({0, 1, 3, 2, 4});
             exec_dst = dst;
             break;
+        case Param::Mode::NHWCD4I_NHWC:
         case Param::Mode::NHWCD4_NHWC:
             // src is {N, H, CB, W, 4}
             // dst is {N, H, W, C},
diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index 1dab6552ed2e521176a41285fc430453d0e46c18..10270bc3a394c14892767be7a0f29c3f009186b7 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -1002,7 +1002,9 @@ void ConvertFormatPass::apply(OptState& state) const {
     rewriter.apply_inplace();
     //! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) +
-    //! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4)
+    //! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4). Merge
+    //! consecutive relayout_format(NHWCD4 -> NCHW) + dimshuffle(NCHW -> NHWC) to one
+    //! relayout_format(NHWCD4 -> NHWC).
     auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) {
         auto opr_is_relayout = [](OperatorNodeBase* opr) {
             return opr->try_cast_final<opr::RelayoutFormat>();
         };
@@ -1019,23 +1021,48 @@ void ConvertFormatPass::apply(OptState& state) const {
             }
             return false;
         };
-        auto this_opr_is_relayout = opr_is_relayout(opr);
-        auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr);
-        if (this_opr_is_relayout) {
-            prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr());
-        }
-        if (this_opr_is_relayout && prev_opr_is_dimshuffle) {
-            if (this_opr_is_relayout->param().mode ==
-                        megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I &&
-                match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) {
-                auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0));
-                auto new_param = megdnn::param::RelayoutFormat();
-                new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I;
-                auto new_opr = opr::RelayoutFormat::make(inp, new_param);
-                rewriter.replace_var(opr->output(0), new_opr.node(), nullptr);
-            }
-        } else {
-            rewriter.auto_replace_outputs(opr);
-        }
+        //! dimshuffle + relayout_format
+        {
+            auto this_opr_is_relayout = opr_is_relayout(opr);
+            auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr);
+            if (this_opr_is_relayout) {
+                prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr());
+            }
+            if (this_opr_is_relayout && prev_opr_is_dimshuffle) {
+                //! megengine only accept NCHW input
+                if (this_opr_is_relayout->param().mode ==
+                            megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I &&
+                    match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) {
+                    auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0));
+                    auto new_param = megdnn::param::RelayoutFormat();
+                    new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I;
+                    auto new_opr = opr::RelayoutFormat::make(inp, new_param);
+                    rewriter.replace_var(opr->output(0), new_opr.node(), nullptr);
+                }
+            } else {
+                rewriter.auto_replace_outputs(opr);
+            }
+        }
+        //! relayout_format + dimshuffle
+        {
+            auto this_opr_is_dimshuffle = opr_is_dimshuffle(opr);
+            auto prev_opr_is_relayout = static_cast<opr::RelayoutFormat*>(nullptr);
+            if (this_opr_is_dimshuffle) {
+                prev_opr_is_relayout = opr_is_relayout(opr->input(0)->owner_opr());
+            }
+            if (this_opr_is_dimshuffle && prev_opr_is_relayout) {
+                if (prev_opr_is_relayout->param().mode ==
+                            megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW &&
+                    match_pattern(this_opr_is_dimshuffle->param(), {0, 2, 3, 1})) {
+                    auto inp = rewriter.get_var(prev_opr_is_relayout->input(0));
+                    auto new_param = megdnn::param::RelayoutFormat();
+                    new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NHWC;
+                    auto new_opr = opr::RelayoutFormat::make(inp, new_param);
+                    rewriter.replace_var(opr->output(0), new_opr.node(), nullptr);
+                }
+            } else {
+                rewriter.auto_replace_outputs(opr);
+            }
+        }
     };
     state.graph().iter(on_opr_merge);
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 4f8aa6e9d4ab37fea6259398fd415866072d8540..2256840e6fd9f7c28553c43c32a12bfc2d4e19df 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -1365,6 +1365,71 @@ TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+TEST(TestGoptInference, MergeRelayoutFormatAndDimShuffle) {
+    // hwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+
+    auto host_x = gen({2, 8, 16, 32}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+
+    auto a = mkvar("a", {1});
+    auto b = mkvar("b", {1});
+    auto z = x * a + b;
+
+    //! to NHWC
+    auto y = opr::Dimshuffle::make(z, {0, 2, 3, 1});
+
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nhwcd4();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_EQ(0, find_opr_num<opr::Dimshuffle>(y_opt));
+    auto check = [](SymbolVar endpoint) -> bool {
+        bool valid = true;
+        auto cb = [&](cg::OperatorNodeBase* opr) {
+            if (opr->same_type<opr::RelayoutFormat>()) {
+                auto mode = opr->try_cast_final<opr::RelayoutFormat>()->param().mode;
+                //! The first relayout_format opr's mode is NCHW_NHWCD4I. The second is
+                //! NHWCD4I_NHWC
+                if (mode == megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I ||
+                    mode == megdnn::param::RelayoutFormat::Mode::NHWCD4I_NHWC) {
+                    valid &= true;
+                } else {
+                    valid &= false;
+                }
+            }
+        };
+        cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
+        return valid;
+    };
+    ASSERT_EQ(true, check(y_opt));
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.MergeRelayoutFormatAndDimShuffle.json"));
+
+    HostTensorND host_y;
+    HostTensorND host_y_opt;
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+
+    *host_x = *gen({8, 8, 16, 16}, cn);
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
     // hwcd4 is only supported in naive handle
     NaiveMegDNNHandleScope naive_megdnn_handle;