diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h index 9be0cecc8ddaa46217c3efb84a258d3e7e9105f5..ecf75d97050acb7fae96a3944bd9ccefce6ca112 100644 --- a/lite/include/lite/network.h +++ b/lite/include/lite/network.h @@ -97,7 +97,7 @@ struct LITE_API Options { bool no_profiling_on_shape_change = false; uint8_t jit_level = 0; uint8_t comp_node_seq_record_level = 0; - uint8_t graph_opt_level = 0; + uint8_t graph_opt_level = 2; uint16_t async_exec_level = 1; //! layout transform options @@ -368,7 +368,6 @@ public: const std::shared_ptr src_network); //! set global layout transform optimization for network - static void enable_global_layout_transform(std::shared_ptr network); //! dump network after global layout transform optimization diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py index c8072791349b8cf090d0b57f20124d2770583db4..6c64890eaf4b10775233bc6c67dbd82c40d49348 100644 --- a/lite/pylite/megenginelite/network.py +++ b/lite/pylite/megenginelite/network.py @@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase): ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), + ("LITE_enable_global_layout_transform", [_Cnetwork]), + ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]), ] @@ -610,3 +612,10 @@ class LiteNetwork(object): def get_static_memory_alloc_info(self, log_dir="logs/test"): c_log_dir = log_dir.encode("utf-8") self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) + + def enable_global_layout_transform(self): + self._api.LITE_enable_global_layout_transform(self._network) + + def dump_layout_transform_model(self, model_file): + c_file = model_file.encode("utf-8") + self._api.LITE_dump_layout_transform_model(self._network, c_file) diff --git a/lite/pylite/test/test_network.py b/lite/pylite/test/test_network.py index 70d4aecf203c21e5c6b6173b8b8de3871f4eea5d..aaeea5e448a1ee8759a390ef4f957d37792e92b4 100644 --- a/lite/pylite/test/test_network.py +++ b/lite/pylite/test/test_network.py @@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet): network.wait() self.check_correct(out_array) + + def test_enable_global_layout_transform(self): + network = LiteNetwork() + network.enable_global_layout_transform() + network.load(self.model_path) + self.do_forward(network) + + def test_dump_layout_transform_model(self): + network = LiteNetwork() + network.enable_global_layout_transform() + network.load(self.model_path) + network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") + self.do_forward(network) + + fi = open("./model_afer_layoutTrans.mgb", "r") + fi.close() + os.remove("./model_afer_layoutTrans.mgb") diff --git a/lite/pylite/test/test_network_cuda.py b/lite/pylite/test/test_network_cuda.py index 56e74b247d923367a3d44bbb9e8e4d91634cb2d8..177b934047b0da1f6586f917aef9d674e7c796e5 100644 --- a/lite/pylite/test/test_network_cuda.py +++ b/lite/pylite/test/test_network_cuda.py @@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda): | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE ) self.do_forward(network) + + @require_cuda() + def test_enable_global_layout_transform(self): + network = LiteNetwork() + network.enable_global_layout_transform() + network.load(self.model_path) + self.do_forward(network) + + @require_cuda() + def test_dump_layout_transform_model(self): + network = LiteNetwork() + network.enable_global_layout_transform() + network.load(self.model_path) + network.dump_layout_transform_model("./model_afer_layoutTrans.mgb") + self.do_forward(network) + + fi = open("./model_afer_layoutTrans.mgb", "r") + fi.close() + os.remove("./model_afer_layoutTrans.mgb") diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index 74c30abd321299772c75208ea5f29c5b8e5732f4..3b40b6659f8714edb10a40fed6a1bcd19bbcb5aa 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -406,7 +406,7 @@ void NetworkImplDft::load_model( use_tensorrt(); } - m_load_result = m_loader->load(m_load_config, false); + m_load_result = m_loader->load(m_load_config, true); global_layout_transform(); diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index 5b9334d88be4c6d35033307c8ab58c8e0f341006..c7cab766a807a705141ea3393be8c842bd27965e 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) { } TEST(TestNetWork, GlabalLayoutTransform) { - // set_log_level(LiteLogLevel::DEBUG); auto tensor = get_input_data("./input_data.npy"); std::string model_path = "./shufflenet.mge"; std::string input_name = "data"; @@ -931,6 +930,7 @@ TEST(TestNetWork, GlabalLayoutTransform) { network->forward(); network->wait(); ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); + remove(dump_model_name.c_str()); } TEST(TestNetWork, GetDeviceType) {