From f67086adde310957afe19de929f70b24d535c201 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 28 Feb 2022 10:56:33 +0800 Subject: [PATCH] fix(lite): fix lite global layout transform symvar replace error GitOrigin-RevId: 7ac74a596ac01a0a76ee0fca25b29bdc5722e381 --- lite/pylite/test/test_network_device.py | 6 ++++-- lite/src/mge/network_impl.cpp | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/lite/pylite/test/test_network_device.py b/lite/pylite/test/test_network_device.py index 177b93404..c240cb128 100644 --- a/lite/pylite/test/test_network_device.py +++ b/lite/pylite/test/test_network_device.py @@ -275,14 +275,16 @@ class TestNetwork(TestShuffleNetCuda): @require_cuda() def test_enable_global_layout_transform(self): - network = LiteNetwork() + config_ = LiteConfig(device_type=LiteDeviceType.LITE_CUDA) + network = LiteNetwork(config=config_) network.enable_global_layout_transform() network.load(self.model_path) self.do_forward(network) @require_cuda() def test_dump_layout_transform_model(self): - network = LiteNetwork() + config_ = LiteConfig(device_type=LiteDeviceType.LITE_CUDA) + network = LiteNetwork(config=config_) network.enable_global_layout_transform() network.load(self.model_path) network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index b2f8558ff..04ec582a8 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -365,8 +365,22 @@ void NetworkImplDft::adapt_option_valid() { void NetworkImplDft::global_layout_transform() { if (m_set_layout_transform) { - m_load_result.output_var_list = mgb::gopt::layout_transform( + mgb::ThinHashMap out_var_map; + auto output_var_array = mgb::gopt::layout_transform( m_load_result.output_var_list, m_layout_transform_target); + // replace symvar in output_var_list + for (size_t idx = 0; idx < output_var_array.size(); ++idx) { + out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx]; + m_load_result.output_var_list[idx] = output_var_array[idx]; + } + // replace symvar in output_var_map_id + for (auto&& item : m_load_result.output_var_map_id) { + item.second = out_var_map[item.second]; + } + // replace symvar in output_var_map + for (auto&& item : m_load_result.output_var_map) { + item.second = out_var_map[item.second]; + } } } -- GitLab