提交 f67086ad 编写于 作者: M Megvii Engine Team

fix(lite): fix lite global layout transform symvar replace error

GitOrigin-RevId: 7ac74a596ac01a0a76ee0fca25b29bdc5722e381
上级 4462953f
......@@ -275,14 +275,16 @@ class TestNetwork(TestShuffleNetCuda):
@require_cuda()
def test_enable_global_layout_transform(self):
    """Global layout transform must be enabled *before* load() so the
    optimizer can rewrite the graph for the CUDA backend at load time."""
    # The network must be constructed with a CUDA config; a default
    # (CPU) LiteNetwork would make the CUDA layout target meaningless.
    config_ = LiteConfig(device_type=LiteDeviceType.LITE_CUDA)
    network = LiteNetwork(config=config_)
    network.enable_global_layout_transform()
    network.load(self.model_path)
    # A forward pass verifies the transformed graph still executes.
    self.do_forward(network)
@require_cuda()
def test_dump_layout_transform_model(self):
    """After a CUDA global layout transform, the optimized model should
    be dumpable back to disk via dump_layout_transform_model()."""
    # Construct the network with a CUDA device config so the layout
    # transform targets CUDA kernels.
    config_ = LiteConfig(device_type=LiteDeviceType.LITE_CUDA)
    network = LiteNetwork(config=config_)
    network.enable_global_layout_transform()
    network.load(self.model_path)
    # Serialize the post-transform graph; path kept identical to the
    # original test to preserve any cleanup expectations.
    network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
......
......@@ -365,8 +365,22 @@ void NetworkImplDft::adapt_option_valid() {
//! Apply the user-requested global layout transform and patch every
//! place the load result keeps a handle to an output SymbolVar, so all
//! views (list, id map, name map) refer to the transformed vars.
void NetworkImplDft::global_layout_transform() {
    if (m_set_layout_transform) {
        // old var -> transformed var; needed because output_var_map_id /
        // output_var_map alias the same vars as output_var_list and must
        // be rewritten consistently (the original code only replaced the
        // list, leaving the maps pointing at stale vars).
        mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
        auto output_var_array = mgb::gopt::layout_transform(
                m_load_result.output_var_list, m_layout_transform_target);
        // replace symvar in output_var_list (positions correspond 1:1)
        for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
            out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx];
            m_load_result.output_var_list[idx] = output_var_array[idx];
        }
        // replace symvar in output_var_map_id
        for (auto&& item : m_load_result.output_var_map_id) {
            // NOTE(review): assumes every mapped var appears in
            // output_var_list; otherwise operator[] inserts a
            // default-constructed SymbolVar — confirm against loader.
            item.second = out_var_map[item.second];
        }
        // replace symvar in output_var_map
        for (auto&& item : m_load_result.output_var_map) {
            item.second = out_var_map[item.second];
        }
    }
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册