From 5ebc9d50b7fc8b881e2d63d5cce4d1b4c3116f69 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 7 Mar 2022 19:47:06 +0800
Subject: [PATCH] fix(pylite): fix lite global layout transform and fast run
 conflict error

GitOrigin-RevId: 910c8da19f3c9973a088b70b782e91f6b366d4f9
---
 lite/load_and_run/src/models/model_lite.h |  1 -
 lite/pylite/test/test_network.py          | 26 ++++++++++++++++
 lite/pylite/test/test_network_device.py   | 28 +++++++++++++++++
 lite/src/mge/network_impl.cpp             | 14 ++++-----
 .../opr_format_modifier.cpp               | 30 +++++++++----------
 5 files changed, 76 insertions(+), 23 deletions(-)

diff --git a/lite/load_and_run/src/models/model_lite.h b/lite/load_and_run/src/models/model_lite.h
index 88323beb9..2e687e202 100644
--- a/lite/load_and_run/src/models/model_lite.h
+++ b/lite/load_and_run/src/models/model_lite.h
@@ -40,7 +40,6 @@ public:
     void wait() override;
 
     //! enable global layout transform
-
     void set_layout_transform(bool state) { enable_layout_transform = state; }
 
     //! get the network of lite model
diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py
index aaeea5e44..ee502c20e 100644
--- a/lite/pylite/test/test_network.py
+++ b/lite/pylite/test/test_network.py
@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+
+    def test_fast_run_and_global_layout_transform(self):
+
+        config_ = LiteConfig()
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
diff --git a/lite/pylite/test/test_network_device.py b/lite/pylite/test/test_network_device.py
index c240cb128..a5b6a6752 100644
--- a/lite/pylite/test/test_network_device.py
+++ b/lite/pylite/test/test_network_device.py
@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+
+    @require_cuda()
+    def test_fast_run_and_global_layout_transform(self):
+
+        config_ = LiteConfig()
+        config_.device_type = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
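[Editor's note] The two new tests above drive the same scenario: fast-run
profiling (LITE_ALGO_PROFILE | LITE_ALGO_OPTIMIZED) and the global layout
transform enabled on one network, with both artifacts dumped and the
persistent cache reloaded afterwards. A condensed sketch of that flow,
assuming the megenginelite package exports the names used in the tests;
"shufflenet.mge" is a placeholder model path, not from this patch:

    from megenginelite import (
        LiteAlgoSelectStrategy,
        LiteConfig,
        LiteGlobal,
        LiteNetwork,
    )

    network = LiteNetwork(LiteConfig())
    # Profile candidate algorithms, then keep only the optimized ones.
    network.set_network_algo_policy(
        LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
        | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
    )
    # Request the global layout transform before loading the model.
    network.enable_global_layout_transform()
    network.load("shufflenet.mge")  # placeholder model path

    # ... fill inputs and run one forward pass so profiling results exist ...

    # Persist both artifacts produced by the profiled, transformed run.
    network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
    LiteGlobal.dump_persistent_cache("./algo_cache")

Before this fix, enabling both features on the same network hit the conflict
named in the commit subject; the C++ changes below reorder when the execution
policy is applied so the combination works.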
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index 04ec582a8..3b7d0ef68 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -422,6 +422,8 @@ void NetworkImplDft::load_model(
 
     m_load_result = m_loader->load(m_load_config, true);
 
+    modify_exection_policy();
+
     global_layout_transform();
 
     adapt_option_valid();
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model(
 }
 
 void NetworkImplDft::compile_graph() {
-    modify_exection_policy();
     replace_dev_input_pass();
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy(
     if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
         dst_strategy = dst_strategy | S::OPTIMIZED;
     }
-    m_execution_policy = dst_strategy;
+    if (static_cast<uint32_t>(dst_strategy) != 0)
+        m_execution_policy = dst_strategy;
 
     auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
     fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy(
 }
 
 void NetworkImplDft::modify_exection_policy() {
-    mgb::SymbolVarArray vars;
-    for (auto i : m_output_spec) {
-        vars.push_back(i.first);
-    }
-    if (static_cast<uint32_t>(m_execution_policy) != 0)
+    auto& vars = m_load_result.output_var_list;
+    if (static_cast<uint32_t>(m_execution_policy) != 0) {
         mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
+    }
 }
 
 //! set opr algorithm selection strategy in the network
diff --git a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
index 5f9e27228..70e605734 100644
--- a/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
+++ b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp
@@ -289,21 +289,21 @@ namespace intl {
 template <typename Opr>
 struct OprFormatModifier;
 
-#define INST(_Opr)                                                          \
-    template <>                                                             \
-    struct OprFormatModifier<_Opr> {                                        \
-        using OprFormat = typename _Opr::Param::Format;                     \
-        static VarNode* make(                                               \
-                OprFormat opr_format, const VarNodeArray& i,                \
-                const cg::OperatorNodeBase* opr_) {                         \
-            MIDOUT_B(_Opr)                                                  \
-            auto&& opr = opr_->cast_final_safe<_Opr>();                     \
-            auto param = opr.param();                                       \
-            param.format = opr_format;                                      \
-            return OprWithPolicyMaker<_Opr>::make(                          \
-                    i, param, opr.execution_policy(), opr.config());        \
-            MIDOUT_E                                                        \
-        }                                                                   \
+#define INST(_Opr)                                                            \
+    template <>                                                               \
+    struct OprFormatModifier<_Opr> {                                          \
+        using OprFormat = typename _Opr::Param::Format;                       \
+        static VarNode* make(                                                 \
+                OprFormat opr_format, const VarNodeArray& i,                  \
+                const cg::OperatorNodeBase* opr_) {                          \
+            MIDOUT_B(_Opr)                                                    \
+            auto&& opr = opr_->cast_final_safe<_Opr>();                       \
+            auto param = opr.param();                                         \
+            param.format = opr_format;                                        \
+            return OprWithPolicyMaker<_Opr>::make(                            \
+                    i, param, opr.execution_policy_transient(), opr.config());\
+            MIDOUT_E                                                          \
+        }                                                                     \
     };
 INST(Convolution);
 INST(ConvBiasForward);
-- 
GitLab
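[Editor's note] The core of the C++ fix: modify_exection_policy() now runs
inside load_model(), before global_layout_transform(), and walks
m_load_result.output_var_list instead of m_output_spec (which is not built
until compile_graph()). The layout pass then copies each operator's policy
via execution_policy_transient(), presumably so it can read the policy
without triggering algorithm selection. One consequence visible from Python
is that the artifacts dumped by a profiled run can be reused in a fresh
process. A sketch, with paths matching the tests above, under the same
megenginelite assumptions as before and assuming the dumped .mgb can be
loaded back with LiteNetwork.load:

    from megenginelite import LiteConfig, LiteGlobal, LiteNetwork

    # Point lite at the fast-run results so loading does not re-profile.
    LiteGlobal.set_persistent_cache(path="./algo_cache")

    network = LiteNetwork(LiteConfig())
    # Load the layout-transformed model dumped by the first run.
    network.load("./model_afer_layoutTrans.mgb")
    # ... fill inputs, then network.forward() and network.wait() ...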