From f7e034b5067ccbcbc53c9853b5d97f378c940818 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 31 Dec 2021 16:18:37 +0800
Subject: [PATCH] feat(lite): add global layout transform python interface for
 lite

GitOrigin-RevId: f159f492082911a366dc5fe2f60f92af290636f4
---
 lite/include/lite/network.h           |  3 +--
 lite/pylite/megenginelite/network.py  |  9 +++++++++
 lite/pylite/test/test_network.py      | 17 +++++++++++++++++
 lite/pylite/test/test_network_cuda.py | 19 +++++++++++++++++++
 lite/src/mge/network_impl.cpp         |  2 +-
 lite/test/test_network.cpp            |  2 +-
 6 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h
index 9be0cecc8..ecf75d970 100644
--- a/lite/include/lite/network.h
+++ b/lite/include/lite/network.h
@@ -97,7 +97,7 @@ struct LITE_API Options {
     bool no_profiling_on_shape_change = false;
     uint8_t jit_level = 0;
     uint8_t comp_node_seq_record_level = 0;
-    uint8_t graph_opt_level = 0;
+    uint8_t graph_opt_level = 2;
     uint16_t async_exec_level = 1;

     //! layout transform options
@@ -368,7 +368,6 @@ public:
             const std::shared_ptr<Network> src_network);

     //! set global layout transform optimization for network
-
     static void enable_global_layout_transform(std::shared_ptr<Network> network);

     //! dump network after global layout transform optimization
diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py
index c80727913..6c64890ea 100644
--- a/lite/pylite/megenginelite/network.py
+++ b/lite/pylite/megenginelite/network.py
@@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase):
         ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
         ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
         ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
+        ("LITE_enable_global_layout_transform", [_Cnetwork]),
+        ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
     ]


@@ -610,3 +612,10 @@ class LiteNetwork(object):
     def get_static_memory_alloc_info(self, log_dir="logs/test"):
         c_log_dir = log_dir.encode("utf-8")
         self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)
+
+    def enable_global_layout_transform(self):
+        self._api.LITE_enable_global_layout_transform(self._network)
+
+    def dump_layout_transform_model(self, model_file):
+        c_file = model_file.encode("utf-8")
+        self._api.LITE_dump_layout_transform_model(self._network, c_file)
diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py
index 70d4aecf2..aaeea5e44 100644
--- a/lite/pylite/test/test_network.py
+++ b/lite/pylite/test/test_network.py
@@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet):
         network.wait()

         self.check_correct(out_array)
+
+    def test_enable_global_layout_transform(self):
+        network = LiteNetwork()
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+
+    def test_dump_layout_transform_model(self):
+        network = LiteNetwork()
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
+        self.do_forward(network)
+
+        fi = open("./model_afer_layoutTrans.mgb", "r")
+        fi.close()
+        os.remove("./model_afer_layoutTrans.mgb")
diff --git a/lite/pylite/test/test_network_cuda.py b/lite/pylite/test/test_network_cuda.py
index 56e74b247..177b93404 100644
--- a/lite/pylite/test/test_network_cuda.py
+++ b/lite/pylite/test/test_network_cuda.py
@@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda):
                 | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
             )
         self.do_forward(network)
+
+    @require_cuda()
+    def test_enable_global_layout_transform(self):
+        network = LiteNetwork()
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+
+    @require_cuda()
+    def test_dump_layout_transform_model(self):
+        network = LiteNetwork()
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
+        self.do_forward(network)
+
+        fi = open("./model_afer_layoutTrans.mgb", "r")
+        fi.close()
+        os.remove("./model_afer_layoutTrans.mgb")
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index 74c30abd3..3b40b6659 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -406,7 +406,7 @@ void NetworkImplDft::load_model(
         use_tensorrt();
     }

-    m_load_result = m_loader->load(m_load_config, false);
+    m_load_result = m_loader->load(m_load_config, true);

     global_layout_transform();

diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index 5b9334d88..c7cab766a 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) {
 }

 TEST(TestNetWork, GlabalLayoutTransform) {
-    // set_log_level(LiteLogLevel::DEBUG);
     auto tensor = get_input_data("./input_data.npy");
     std::string model_path = "./shufflenet.mge";
     std::string input_name = "data";
@@ -931,6 +930,7 @@
     network->forward();
     network->wait();
     ASSERT_TRUE(fopen(dump_model_name.c_str(), "r"));
+    remove(dump_model_name.c_str());
 }

 TEST(TestNetWork, GetDeviceType) {
--
GitLab
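
Usage note (not part of the patch): a minimal sketch of how the new pylite
interface is intended to be driven, following the call order used in the tests
above (enable the transform before load(), dump after load()). The model path
and output path below are placeholder values, not files shipped with the commit.

    from megenginelite import LiteNetwork

    network = LiteNetwork()
    # turn on the global layout transform search; the tests call this before load()
    network.enable_global_layout_transform()
    network.load("./shufflenet.mge")  # placeholder model path
    # dump the transformed model so later runs can reuse it without re-searching
    network.dump_layout_transform_model("./shufflenet_layout_trans.mge")
    # ... set inputs, then network.forward() / network.wait() as usual ...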