提交 f7e034b5 编写于 作者: M Megvii Engine Team

feat(lite): add global layout transform python interface for lite

GitOrigin-RevId: f159f492082911a366dc5fe2f60f92af290636f4
上级 e70c07a2
...@@ -97,7 +97,7 @@ struct LITE_API Options { ...@@ -97,7 +97,7 @@ struct LITE_API Options {
bool no_profiling_on_shape_change = false; bool no_profiling_on_shape_change = false;
uint8_t jit_level = 0; uint8_t jit_level = 0;
uint8_t comp_node_seq_record_level = 0; uint8_t comp_node_seq_record_level = 0;
uint8_t graph_opt_level = 0; uint8_t graph_opt_level = 2;
uint16_t async_exec_level = 1; uint16_t async_exec_level = 1;
//! layout transform options //! layout transform options
...@@ -368,7 +368,6 @@ public: ...@@ -368,7 +368,6 @@ public:
const std::shared_ptr<Network> src_network); const std::shared_ptr<Network> src_network);
//! set global layout transform optimization for network //! set global layout transform optimization for network
static void enable_global_layout_transform(std::shared_ptr<Network> network); static void enable_global_layout_transform(std::shared_ptr<Network> network);
//! dump network after global layout transform optimization //! dump network after global layout transform optimization
......
...@@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase): ...@@ -362,6 +362,8 @@ class _NetworkAPI(_LiteCObjBase):
("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]), ("LITE_set_start_callback", [_Cnetwork, LiteStartCallback]),
("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]), ("LITE_set_finish_callback", [_Cnetwork, LiteFinishCallback]),
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]), ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
("LITE_enable_global_layout_transform", [_Cnetwork]),
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
] ]
...@@ -610,3 +612,10 @@ class LiteNetwork(object): ...@@ -610,3 +612,10 @@ class LiteNetwork(object):
def get_static_memory_alloc_info(self, log_dir="logs/test"): def get_static_memory_alloc_info(self, log_dir="logs/test"):
c_log_dir = log_dir.encode("utf-8") c_log_dir = log_dir.encode("utf-8")
self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir) self._api.LITE_get_static_memory_alloc_info(self._network, c_log_dir)
def enable_global_layout_transform(self):
self._api.LITE_enable_global_layout_transform(self._network)
def dump_layout_transform_model(self, model_file):
c_file = model_file.encode("utf-8")
self._api.LITE_dump_layout_transform_model(self._network, c_file)
...@@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet): ...@@ -451,3 +451,20 @@ class TestNetwork(TestShuffleNet):
network.wait() network.wait()
self.check_correct(out_array) self.check_correct(out_array)
def test_enable_global_layout_transform(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
self.do_forward(network)
def test_dump_layout_transform_model(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
self.do_forward(network)
fi = open("./model_afer_layoutTrans.mgb", "r")
fi.close()
os.remove("./model_afer_layoutTrans.mgb")
...@@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda): ...@@ -272,3 +272,22 @@ class TestNetwork(TestShuffleNetCuda):
| LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE | LiteAlgoSelectStrategy.LITE_ALGO_REPRODUCIBLE
) )
self.do_forward(network) self.do_forward(network)
@require_cuda()
def test_enable_global_layout_transform(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
self.do_forward(network)
@require_cuda()
def test_dump_layout_transform_model(self):
network = LiteNetwork()
network.enable_global_layout_transform()
network.load(self.model_path)
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
self.do_forward(network)
fi = open("./model_afer_layoutTrans.mgb", "r")
fi.close()
os.remove("./model_afer_layoutTrans.mgb")
...@@ -406,7 +406,7 @@ void NetworkImplDft::load_model( ...@@ -406,7 +406,7 @@ void NetworkImplDft::load_model(
use_tensorrt(); use_tensorrt();
} }
m_load_result = m_loader->load(m_load_config, false); m_load_result = m_loader->load(m_load_config, true);
global_layout_transform(); global_layout_transform();
......
...@@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) { ...@@ -910,7 +910,6 @@ TEST(TestNetWork, LoadPackedModel) {
} }
TEST(TestNetWork, GlabalLayoutTransform) { TEST(TestNetWork, GlabalLayoutTransform) {
// set_log_level(LiteLogLevel::DEBUG);
auto tensor = get_input_data("./input_data.npy"); auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge"; std::string model_path = "./shufflenet.mge";
std::string input_name = "data"; std::string input_name = "data";
...@@ -931,6 +930,7 @@ TEST(TestNetWork, GlabalLayoutTransform) { ...@@ -931,6 +930,7 @@ TEST(TestNetWork, GlabalLayoutTransform) {
network->forward(); network->forward();
network->wait(); network->wait();
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r")); ASSERT_TRUE(fopen(dump_model_name.c_str(), "r"));
remove(dump_model_name.c_str());
} }
TEST(TestNetWork, GetDeviceType) { TEST(TestNetWork, GetDeviceType) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册