From 1fead9b6b06af656400788737246926428a96f36 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 22 Feb 2022 19:36:04 +0800 Subject: [PATCH] feat(gopt): merge consecutive dimshuffle and relayout to one relayout to optimize CD4 performance GitOrigin-RevId: 16f22baa80ddf228ebc3fc8b33c340ce4635e46f --- src/gopt/impl/inference.cpp | 40 +++++++++++++++++++++++++++++++ src/gopt/test/inference.cpp | 47 +++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 7af8073df..1dab6552e 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1000,6 +1000,46 @@ void ConvertFormatPass::apply(OptState& state) const { }; state.graph().iter(on_opr); rewriter.apply_inplace(); + + //! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) + + //! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4) + auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) { + auto opr_is_relayout = [](OperatorNodeBase* opr) { + return opr->try_cast_final(); + }; + auto opr_is_dimshuffle = [](OperatorNodeBase* opr) { + return opr->try_cast_final(); + }; + auto match_pattern = [](const opr::Dimshuffle::Param& param, + const std::vector&& patten) { + if (param.pattern_len == patten.size() && param.pattern[0] == patten[0] && + param.pattern[1] == patten[1] && param.pattern[2] == patten[2] && + param.pattern[3] == patten[3]) { + return true; + } + return false; + }; + auto this_opr_is_relayout = opr_is_relayout(opr); + auto prev_opr_is_dimshuffle = static_cast(nullptr); + if (this_opr_is_relayout) { + prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr()); + } + if (this_opr_is_relayout && prev_opr_is_dimshuffle) { + if (this_opr_is_relayout->param().mode == + megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I && + match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) { + auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0)); 
auto new_param = megdnn::param::RelayoutFormat(); + new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I; + auto new_opr = opr::RelayoutFormat::make(inp, new_param); + rewriter.replace_var(opr->output(0), new_opr.node(), nullptr); + } + } else { + rewriter.auto_replace_outputs(opr); + } + }; + state.graph().iter(on_opr_merge); + rewriter.apply_inplace(); MIDOUT_E } diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 5d8b65a25..4f8aa6e9d 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1318,6 +1318,53 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise0) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } +TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) { + // hwcd4 is only supported in naive handle + NaiveMegDNNHandleScope naive_megdnn_handle; + + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp) { + return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); + }; + + auto host_x = gen({8, 8, 8, 8}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + auto d0 = opr::Dimshuffle::make(x, {0, 3, 1, 2}); + + auto a = mkvar("a", {1}); + auto b = mkvar("b", {1}); + auto y = d0 * a + b; + + SymbolVar y_opt; + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nhwcd4(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + + ASSERT_EQ( + megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I, + find_opr(y_opt).param().mode); + + ASSERT_EQ(0, find_opr_num(y_opt)); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.MergeDimShuffleAndRelayoutFormat.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile( + {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, 
host_y_opt, 1e-3); + + *host_x = *gen({8, 8, 16, 16}, cn); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); +} + TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { // hwcd4 is only supported in naive handle NaiveMegDNNHandleScope naive_megdnn_handle; -- GitLab