提交 485e56ca 编写于 作者: M Megvii Engine Team

fix(lite): fix lite get model io info name error

GitOrigin-RevId: 904d86831d1a9dc6ebefb1c59463976e7667a339
上级 a450d0f5
......@@ -170,22 +170,24 @@ struct InnerIO {
};
InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
InnerIO innner_io;
for (size_t i = 0; i < network_io.inputs.size(); i++) {
InnerIO inner_io;
size_t id = 0;
inner_io.names.resize(network_io.inputs.size() + network_io.outputs.size());
for (size_t i = 0; i < network_io.inputs.size(); i++, id++) {
lite::IO io = network_io.inputs[i];
innner_io.names.push_back(io.name);
innner_io.inputs.push_back(
{innner_io.names.back().c_str(), io.is_host, io.io_type,
inner_io.names[id] = io.name;
inner_io.inputs.push_back(
{inner_io.names[id].c_str(), io.is_host, io.io_type,
convert_to_clayout(io.config_layout)});
}
for (size_t i = 0; i < network_io.outputs.size(); i++) {
for (size_t i = 0; i < network_io.outputs.size(); i++, id++) {
lite::IO io = network_io.outputs[i];
innner_io.names.push_back(io.name);
innner_io.outputs.push_back(
{innner_io.names.back().c_str(), io.is_host, io.io_type,
inner_io.names[id] = io.name;
inner_io.outputs.push_back(
{inner_io.names[id].c_str(), io.is_host, io.io_type,
convert_to_clayout(io.config_layout)});
}
return innner_io;
return inner_io;
}
lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
......@@ -727,15 +729,6 @@ int write_ios_from_cpp_io(
ios->output_size = inner_io.outputs.size();
ios->inputs = inner_io.inputs.data();
ios->outputs = inner_io.outputs.data();
size_t i = 0;
for (; i < ios->input_size; i++) {
auto io_ptr = ios->inputs + i;
io_ptr->name = inner_io.names[i].c_str();
}
for (; i < ios->output_size; i++) {
auto io_ptr = ios->outputs + i;
io_ptr->name = inner_io.names[i].c_str();
}
LITE_CAPI_END();
}
......
......@@ -78,3 +78,19 @@ class TestGlobal(TestShuffleNet):
phy_ptr2 = LiteGlobal.lookup_physic_ptr(vir_ptr, LiteDeviceType.LITE_AX)
assert phy_ptr.value == phy_ptr2.value
LiteGlobal.clear_memory_pair(vir_ptr, phy_ptr, LiteDeviceType.LITE_AX)
def test_get_model_io_info():
source_dir = os.getenv("LITE_TEST_RESOURCE")
model_path = os.path.join(source_dir, "./ax_models/77-fcf1a1af.axe")
model_io = get_model_io_info(model_path)
input_names = [in_node.name for in_node in model_io.inputs]
output_names = [out_node.name for out_node in model_io.outputs]
assert "op_4121.hardware" in input_names
assert "op_4238.hardware" in input_names
assert "op_4218:u2s.hardware" in input_names
assert "op_5034:add" in output_names
assert "op_5035:add" in output_names
assert "op_5036:add" in output_names
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册