diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h
index ecf75d97050acb7fae96a3944bd9ccefce6ca112..82c44dcb473b3cbb3144ef4b91fb6f222422f9db 100644
--- a/lite/include/lite/network.h
+++ b/lite/include/lite/network.h
@@ -373,6 +373,14 @@ public:
     //! dump network after global layout transform optimization
     static void dump_layout_transform_model(
             std::shared_ptr<Network> network, std::string optimized_model_path);
+
+    //! get the model IO information before loading, from the model path.
+    static NetworkIO get_model_io_info(
+            const std::string& model_path, const Config& config = {});
+
+    //! get the model IO information before loading, from the model memory.
+    static NetworkIO get_model_io_info(
+            const void* model_mem, size_t size, const Config& config = {});
 };
 
 }  // namespace lite
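For context, a minimal usage sketch of the new overloads (not part of the patch; the model file name is an assumption): the returned `NetworkIO` can be passed straight back in to configure the `Network` that actually loads the model.

```cpp
#include <cstdio>
#include <memory>
#include "lite/network.h"

// Sketch: query IO metadata without constructing a Network, then reuse the
// result to configure the Network that performs the real load.
int main() {
    lite::Config config;
    lite::NetworkIO ios = lite::Runtime::get_model_io_info("shufflenet.mge", config);
    for (auto&& in : ios.inputs)
        std::printf("input %s ndim=%zu\n", in.name.c_str(), in.config_layout.ndim);

    auto network = std::make_shared<lite::Network>(config, ios);
    network->load_model("shufflenet.mge");
    return 0;
}
```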
diff --git a/lite/lite-c/include/lite-c/network_c.h b/lite/lite-c/include/lite-c/network_c.h
index 592eba137a8790eca1daac63f2c2915213f2e0b8..7305842fac97b6ee91ae9143652fea6616861071 100644
--- a/lite/lite-c/include/lite-c/network_c.h
+++ b/lite/lite-c/include/lite-c/network_c.h
@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
 LITE_API int LITE_dump_layout_transform_model(
         LiteNetwork network, const char* dump_file_path);
 
+/*! get the model IO information before the model is loaded, from the model path.
+ * \param[in] model_path The model file path
+ * \param[in] config The model config for loading
+ * \param[out] ios The model IO information
+ * \return int if the return is not zero, an error happened; the error message
+ * can be got by LITE_get_last_error
+ */
+LITE_API int LITE_get_model_io_info_by_path(
+        const char* model_path, const LiteConfig config, LiteNetworkIO* ios);
+
+/*! get the model IO information before the model is loaded, from the model memory.
+ * \param[in] model_mem The model memory pointer
+ * \param[in] size The length in bytes of the model memory
+ * \param[in] config The model config for loading
+ * \param[out] ios The model IO information
+ * \return int if the return is not zero, an error happened; the error message
+ * can be got by LITE_get_last_error
+ */
+LITE_API int LITE_get_model_io_info_by_memory(
+        const void* model_mem, size_t size, const LiteConfig config,
+        LiteNetworkIO* ios);
+
 #ifdef __cplusplus
 }
 #endif
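A hedged sketch of calling the new C entry point (not from the patch); it leans on `default_config()` and `LITE_get_last_error()`, which the tests later in this patch also use:

```cpp
#include <cstdio>
#include "lite-c/network_c.h"

// Sketch: list model inputs/outputs through the C API before any LiteNetwork
// exists. Error handling follows the header's contract: non-zero return,
// message available via LITE_get_last_error().
int inspect(const char* model_path) {
    LiteNetworkIO ios;
    if (LITE_get_model_io_info_by_path(model_path, *default_config(), &ios) != 0) {
        std::printf("error: %s\n", LITE_get_last_error());
        return -1;
    }
    for (size_t i = 0; i < ios.input_size; ++i)
        std::printf("input: %s\n", ios.inputs[i].name);
    for (size_t i = 0; i < ios.output_size; ++i)
        std::printf("output: %s\n", ios.outputs[i].name);
    return 0;
}
```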
diff --git a/lite/lite-c/src/network.cpp b/lite/lite-c/src/network.cpp
index 6814c9cead6a2305caa2ac1fb054b91d735cf2ac..8072c5cd8a03c23ab1a28bbf0963baf053d037cf 100644
--- a/lite/lite-c/src/network.cpp
+++ b/lite/lite-c/src/network.cpp
@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
     return network_io;
 }
 
+struct InnerIO {
+    std::vector<std::string> names;
+    std::vector<LiteIO> inputs;
+    std::vector<LiteIO> outputs;
+};
+
+InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
+    InnerIO inner_io;
+    for (size_t i = 0; i < network_io.inputs.size(); i++) {
+        lite::IO io = network_io.inputs[i];
+        inner_io.names.push_back(io.name);
+        inner_io.inputs.push_back(
+                {inner_io.names.back().c_str(), io.is_host, io.io_type,
+                 convert_to_clayout(io.config_layout)});
+    }
+    for (size_t i = 0; i < network_io.outputs.size(); i++) {
+        lite::IO io = network_io.outputs[i];
+        inner_io.names.push_back(io.name);
+        inner_io.outputs.push_back(
+                {inner_io.names.back().c_str(), io.is_host, io.io_type,
+                 convert_to_clayout(io.config_layout)});
+    }
+    return inner_io;
+}
+
 int LITE_make_default_network(LiteNetwork* network) {
     LITE_CAPI_BEGIN();
     LITE_ASSERT(network, "The network pass to LITE api is null");
@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_
     lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
     LITE_CAPI_END();
 }
+
+namespace {
+static LITE_MUTEX mtx_io;
+static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
+    static std::unordered_map<const void*, InnerIO> global_holder;
+    return global_holder;
+}
+
+int write_ios_from_cpp_io(
+        const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
+    LITE_CAPI_BEGIN();
+    LITE_LOCK_GUARD(mtx_io);
+    get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
+    auto&& inner_io = get_global_io_holder()[key];
+    ios->input_size = inner_io.inputs.size();
+    ios->output_size = inner_io.outputs.size();
+    ios->inputs = inner_io.inputs.data();
+    ios->outputs = inner_io.outputs.data();
+    size_t i = 0;
+    for (; i < ios->input_size; i++) {
+        auto io_ptr = ios->inputs + i;
+        io_ptr->name = inner_io.names[i].c_str();
+    }
+    for (; i < ios->input_size + ios->output_size; i++) {
+        auto io_ptr = ios->outputs + (i - ios->input_size);
+        io_ptr->name = inner_io.names[i].c_str();
+    }
+    LITE_CAPI_END();
+}
+
+}  // namespace
+
+int LITE_get_model_io_info_by_path(
+        const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
+    LITE_CAPI_BEGIN();
+    LITE_ASSERT(model_path, "The model_path pass to LITE api is null");
+    auto&& cpp_ios = lite::Runtime::get_model_io_info(
+            std::string{model_path}, convert_to_lite_config(config));
+    return write_ios_from_cpp_io(
+            cpp_ios, ios, reinterpret_cast<const void*>(model_path));
+    LITE_CAPI_END();
+}
+
+int LITE_get_model_io_info_by_memory(
+        const void* model_mem, size_t size, const LiteConfig config,
+        LiteNetworkIO* ios) {
+    LITE_CAPI_BEGIN();
+    LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null");
+    auto&& cpp_ios = lite::Runtime::get_model_io_info(
+            model_mem, size, convert_to_lite_config(config));
+    return write_ios_from_cpp_io(
+            cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
+    LITE_CAPI_END();
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
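The `InnerIO` holder above answers a lifetime question: `LiteIO::name` is a bare `const char*`, so the `std::string` copies backing those pointers must outlive the C call. A standalone sketch of the same ownership pattern, with purely illustrative names (none of these are lite's API):

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct View {
    const char* name;  // non-owning pointer handed to the caller
};

// Storage keyed by a caller-supplied pointer; an entry lives until the same
// key is written again, which is the guarantee the C API needs for the
// const char* fields it hands out.
std::unordered_map<const void*, std::vector<std::string>>& holder() {
    static std::unordered_map<const void*, std::vector<std::string>> h;
    return h;
}

std::vector<View> make_views(const void* key, std::vector<std::string> names) {
    // Move the strings into stable storage first, then take c_str() views,
    // so the pointers refer to the stored strings rather than temporaries.
    auto& stored = holder()[key] = std::move(names);
    std::vector<View> views;
    for (auto& s : stored)
        views.push_back({s.c_str()});
    return views;
}

int main() {
    int key = 0;
    for (auto& v : make_views(&key, {"data", "prob"}))
        std::printf("%s\n", v.name);
    return 0;
}
```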
diff --git a/lite/pylite/megenginelite/network.py b/lite/pylite/megenginelite/network.py
index 6c64890eaf4b10775233bc6c67dbd82c40d49348..05a2b178e1fce728099bab02d9e187df04128a39 100644
--- a/lite/pylite/megenginelite/network.py
+++ b/lite/pylite/megenginelite/network.py
@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase):
         ("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
         ("LITE_enable_global_layout_transform", [_Cnetwork]),
         ("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
+        (
+            "LITE_get_model_io_info_by_path",
+            [c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
+        ),
+        (
+            "LITE_get_model_io_info_by_memory",
+            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
+        ),
     ]
 
 
@@ -619,3 +627,27 @@ def dump_layout_transform_model(self, model_file):
     def dump_layout_transform_model(self, model_file):
         c_file = model_file.encode("utf-8")
         self._api.LITE_dump_layout_transform_model(self._network, c_file)
+
+
+def get_model_io_info(model_path, config=None):
+    """
+    get the model IO information before creating the Network; this IO
+    information can be used to configure the Network.
+    """
+    api = _NetworkAPI()._lib
+    c_path = c_char_p(model_path.encode("utf-8"))
+
+    ios = _LiteNetworkIO()
+
+    if config is not None:
+        api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
+    else:
+        config = LiteConfig()
+        api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
+
+    ret_ios = LiteNetworkIO()
+    for i in range(ios.input_size):
+        ret_ios.add_input(ios.inputs[i])
+    for i in range(ios.output_size):
+        ret_ios.add_output(ios.outputs[i])
+    return ret_ios
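A usage sketch for the Python helper (not from the patch); it assumes a local `shufflenet.mge` and follows megenginelite's public constructor and `load` names:

```python
# Sketch: query IO metadata first, then pass the result as the io argument
# when building the real network.
from megenginelite import LiteConfig, LiteNetwork
from megenginelite.network import get_model_io_info

ios = get_model_io_info("shufflenet.mge")
print([io.name for io in ios.inputs], [io.name for io in ios.outputs])

# The returned LiteNetworkIO is directly usable as the network's IO config.
network = LiteNetwork(config=LiteConfig(), io=ios)
network.load("shufflenet.mge")
```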
diff --git a/lite/pylite/test/test_utils.py b/lite/pylite/test/test_utils.py
index 893cf03eea67cb6171e7ffa4548926505fa8bc2d..679c3077273b3c44b91d7525fa84aa5c14e379ee 100644
--- a/lite/pylite/test/test_utils.py
+++ b/lite/pylite/test/test_utils.py
@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
 import functools
+import os
 
 import numpy as np
 
@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy():
     for i in range(4):
         for j in range(48):
             assert data[i][j // 8][j % 8] == i + 1
+
+
+def test_get_model_io_ahead():
+    source_dir = os.getenv("LITE_TEST_RESOURCE")
+    model_path = os.path.join(source_dir, "shufflenet.mge")
+    ios = get_model_io_info(model_path)
+
+    assert len(ios.inputs) == 1
+    assert ios.inputs[0].name == "data"
+    assert ios.inputs[0].config_layout.shapes[1] == 3
+    assert ios.inputs[0].config_layout.shapes[2] == 224
+    assert ios.inputs[0].config_layout.shapes[3] == 224
+
+    assert len(ios.outputs) == 1
+    assert ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
+    assert ios.outputs[0].config_layout.shapes[0] == 1
+    assert ios.outputs[0].config_layout.shapes[1] == 1000
diff --git a/lite/src/function_base.h b/lite/src/function_base.h
index b824fb7aabf9001071bca9f83951f1bf97b59afa..3db7d4443470e209f22517a8f7bf47bb4bf92a61 100644
--- a/lite/src/function_base.h
+++ b/lite/src/function_base.h
@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft);
 }  // namespace
 
 // if it can't find the function, ignore
-template <typename ret_type, typename... Args>
+template <typename class_type, typename ret_type, typename... Args>
 ret_type try_call_func(std::string func_name, Args... args) {
     mark_used_variable(func_name);
     mark_used_variable(args...);
@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) {
 }
 
 // if it can't find the function, throw error
-template <typename ret_type, typename... Args>
+template <typename class_type, typename ret_type, typename... Args>
 ret_type call_func(std::string func_name, Args... args) {
     mark_used_variable(args...);
-    auto backend_name = class_type_name<NetworkImplDft>()();
+    auto backend_name = class_type_name<class_type>()();
     auto msg_info = func_name + " is not aviliable in " + backend_name + " backend.";
     LITE_THROW(msg_info.c_str());
 }
diff --git a/lite/src/mge/function_dft.h b/lite/src/mge/function_dft.h
index a4d107214d7ca5acdde73d73558ca2cfb612fc74..f44e132c03684b8be1a3e78f3d9ef9fddd7a3826 100644
--- a/lite/src/mge/function_dft.h
+++ b/lite/src/mge/function_dft.h
@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>(
         THROW_FUNC_ERROR(func_name);
     }
 }
+
+template <>
+inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
+        std::string func_name, std::string model_path, Config config) {
+    if (func_name == "get_model_io_info") {
+        return get_model_io_info_dft(model_path, config);
+    } else {
+        THROW_FUNC_ERROR(func_name);
+    }
+}
+
+template <>
+inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
+        std::string func_name, const void* model_mem, size_t size, Config config) {
+    if (func_name == "get_model_io_info") {
+        return get_model_io_info_dft(model_mem, size, config);
+    } else {
+        THROW_FUNC_ERROR(func_name);
+    }
+}
 #undef THROW_FUNC_ERROR
 
 }  // namespace lite
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index 3b40b6659f8714edb10a40fed6a1bcd19bbcb5aa..b2f8558ff2591806368713c49ac7e64081129469 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat
                 "enable_global_layout_transform before"));
     }
 }
+
+NetworkIO lite::get_model_io_info_dft(
+        const std::string& model_path, const Config& config) {
+    FILE* fin = fopen(model_path.c_str(), "rb");
+    LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
+    fseek(fin, 0, SEEK_END);
+    size_t size = ftell(fin);
+    fseek(fin, 0, SEEK_SET);
+    void* ptr = malloc(size);
+    std::shared_ptr<void> buf{ptr, ::free};
+    auto nr = fread(buf.get(), 1, size, fin);
+    LITE_ASSERT(nr == size);
+    fclose(fin);
+    return get_model_io_info_dft(ptr, size, config);
+}
+
+NetworkIO lite::get_model_io_info_dft(
+        const void* model_mem, size_t size, const Config& config) {
+    std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
+    auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
+    auto format =
+            mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
+    if (!format.valid()) {
+        LITE_THROW("invalid model format");
+    }
+    auto loader =
+            mgb::serialization::GraphLoader::make(std::move(input_file), format.val());
+
+    mgb::serialization::GraphLoadConfig load_config;
+    load_config.comp_graph = mgb::ComputingGraph::make();
+    if (config.has_compression) {
+        load_config.tensor_value_loader = decompressed_tensor_value_loader;
+    }
+    auto compnode_locator = to_compnode_locator(config.device_type);
+    load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
+        if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
+            loc.type = compnode_locator.type;
+        }
+        loc.device = compnode_locator.device;
+    };
+    auto load_result = loader->load(load_config, true);
+    NetworkIO IOs;
+    for (auto&& in_tensor_iter : load_result.tensor_map) {
+        IO in_io;
+        in_io.name = in_tensor_iter.first;
+        in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
+        IOs.inputs.push_back(in_io);
+    }
+    auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
+        auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
+        using InferType = mgb::cg::static_infer::InferType;
+        if (static_infer_mgr.get_infer_type(var.node()).shape &
+            (InferType::CONST | InferType::RT_STATIC)) {
+            return static_infer_mgr.infer_shape_fallible(var.node());
+        } else {
+            return nullptr;
+        }
+    };
+    for (auto&& out : load_result.output_var_list) {
+        IO out_io;
+        out_io.name = out.node()->name();
+        if (auto shape = infer_shape(out)) {
+            out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
+        } else {
+            out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
+        }
+        IOs.outputs.push_back(out_io);
+    }
+    return IOs;
+}
 #endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/lite/src/mge/network_impl.h b/lite/src/mge/network_impl.h
index 56e7a07552d5d9e40f53b2dfffd8f0e22b324831..78c5715da4f9296414017bed8b8c313c31c12b54 100644
--- a/lite/src/mge/network_impl.h
+++ b/lite/src/mge/network_impl.h
@@ -262,6 +262,13 @@ private:
 #endif
     std::unique_ptr<mgb::OprIODumpBase> m_iodump;
 };
+//! get the model IO information before the model is loaded by a Network
+NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);
+
+//! get the model IO information before the model is loaded by a Network, given
+//! the model memory and size
+NetworkIO get_model_io_info_dft(
+        const void* model_mem, size_t size, const Config& config);
 
 }  // namespace lite
 
diff --git a/lite/src/network.cpp b/lite/src/network.cpp
index a4ffb4b58cb5f5d2d230259621dcba8e6f67f3f5..44cc8b4ed5c6807c218d7c5c132b2a10b361c67a 100644
--- a/lite/src/network.cpp
+++ b/lite/src/network.cpp
@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model(
         LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
     LITE_ERROR_HANDLER_END
 }
+
+NetworkIO Runtime::get_model_io_info(
+        const std::string& model_path, const Config& config) {
+    LITE_ERROR_HANDLER_BEGIN
+    if (config.backend == LiteBackend::LITE_DEFAULT) {
+        return call_func<NetworkImplDft, NetworkIO>(
+                "get_model_io_info", model_path, config);
+    }
+    LITE_THROW("get_model_io_info is not available in the backend.");
+    LITE_ERROR_HANDLER_END
+}
+
+NetworkIO Runtime::get_model_io_info(
+        const void* model_mem, size_t size, const Config& config) {
+    LITE_ERROR_HANDLER_BEGIN
+    if (config.backend == LiteBackend::LITE_DEFAULT) {
+        return call_func<NetworkImplDft, NetworkIO>(
+                "get_model_io_info", model_mem, size, config);
+    }
+    LITE_THROW("get_model_io_info is not available in the backend.");
+    LITE_ERROR_HANDLER_END
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
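The `call_func<NetworkImplDft, NetworkIO>(...)` calls above rely on the dispatch idiom from `function_base.h` / `function_dft.h`: a primary template that throws for unknown (backend, signature) pairs, plus per-backend explicit specializations that route a function name to the real implementation. A self-contained sketch of that idiom, with illustrative names only:

```cpp
#include <cstdio>
#include <stdexcept>
#include <string>

// Tag type standing in for a backend (lite uses NetworkImplDft here).
struct DftBackend {};

// Primary template: reached when no backend specialization exists for this
// (backend, return type, argument list) combination.
template <typename backend, typename ret, typename... Args>
ret dispatch(std::string name, Args...) {
    throw std::runtime_error(name + " is not available in this backend");
}

int answer_impl() { return 42; }

// Backend-specific specialization: routes the function name string to the
// concrete implementation, mirroring function_dft.h's call_func overloads.
template <>
int dispatch<DftBackend, int>(std::string name) {
    if (name == "answer")
        return answer_impl();
    throw std::runtime_error(name + " is not available in this backend");
}

int main() {
    std::printf("%d\n", dispatch<DftBackend, int>("answer"));
    return 0;
}
```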
diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index 8734e8ee7a7a98167bf3827775a21b95dbc54ae7..78f139d5fb669eeb218164b13fce4b7c145df766 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) {
     ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
 }
 
+TEST(TestNetWork, GetAllIoInfoAhead) {
+    Config config;
+    std::string model_path = "./shufflenet.mge";
+
+    auto ios = Runtime::get_model_io_info(model_path);
+
+    FILE* fin = fopen(model_path.c_str(), "rb");
+    ASSERT_TRUE(fin);
+    fseek(fin, 0, SEEK_END);
+    size_t size = ftell(fin);
+    fseek(fin, 0, SEEK_SET);
+    void* ptr = malloc(size);
+    std::shared_ptr<void> buf{ptr, ::free};
+    auto nr = fread(buf.get(), 1, size, fin);
+    LITE_ASSERT(nr == size);
+    fclose(fin);
+
+    auto ios_mem = Runtime::get_model_io_info(ptr, size);
+
+    ASSERT_EQ(ios.inputs.size(), ios_mem.inputs.size());
+    ASSERT_EQ(ios.inputs.size(), 1);
+
+    ASSERT_EQ(ios.outputs.size(), ios_mem.outputs.size());
+    ASSERT_EQ(ios.outputs.size(), 1);
+
+    ASSERT_TRUE(ios.inputs[0].name == "data");
+    ASSERT_TRUE(ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
+
+    ASSERT_TRUE(ios_mem.inputs[0].name == "data");
+    ASSERT_TRUE(
+            ios_mem.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
+    ASSERT_EQ(ios.inputs[0].config_layout.ndim, 4);
+    ASSERT_EQ(ios.inputs[0].config_layout.shapes[1], 3);
+    ASSERT_EQ(ios.inputs[0].config_layout.shapes[2], 224);
+
+    ASSERT_EQ(ios.outputs[0].config_layout.ndim, 2);
+    ASSERT_EQ(ios.outputs[0].config_layout.shapes[0], 1);
+    ASSERT_EQ(ios.outputs[0].config_layout.shapes[1], 1000);
+
+    ASSERT_EQ(ios_mem.inputs[0].config_layout.ndim, 4);
+    ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[1], 3);
+    ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[2], 224);
+
+    ASSERT_EQ(ios_mem.outputs[0].config_layout.ndim, 2);
+    ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[0], 1);
+    ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[1], 1000);
+}
+
 TEST(TestNetWork, LoadFBSModel) {
     Config config;
     std::string model_path = "./ax.mge";
diff --git a/lite/test/test_network_c.cpp b/lite/test/test_network_c.cpp
index 591e5db60c0b51b3b7c46d3c4167c31901c2d007..d34f6c3a33d42a7c959515c7ec6dab1e813396ed 100644
--- a/lite/test/test_network_c.cpp
+++ b/lite/test/test_network_c.cpp
@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) {
     LITE_destroy_network(c_network);
 }
 
+TEST(TestCapiNetWork, GetAllNameAhead) {
+    std::string model_path = "./shufflenet.mge";
+    LiteNetworkIO ios, ios_mem;
+    LITE_CAPI_CHECK(LITE_get_model_io_info_by_path(
+            model_path.c_str(), *default_config(), &ios));
+    FILE* fin = fopen(model_path.c_str(), "rb");
+    ASSERT_TRUE(fin);
+    fseek(fin, 0, SEEK_END);
+    size_t size = ftell(fin);
+    fseek(fin, 0, SEEK_SET);
+    void* ptr = malloc(size);
+    std::shared_ptr<void> buf{ptr, ::free};
+    auto nr = fread(buf.get(), 1, size, fin);
+    LITE_ASSERT(nr == size);
+    fclose(fin);
+
+    LITE_CAPI_CHECK(
+            LITE_get_model_io_info_by_memory(ptr, size, *default_config(), &ios_mem));
+
+    ASSERT_EQ(ios.input_size, 1);
+    ASSERT_EQ(ios.output_size, 1);
+    ASSERT_EQ(ios_mem.input_size, 1);
+    ASSERT_EQ(ios_mem.output_size, 1);
+
+    ASSERT_TRUE(std::string(ios.inputs->name) == "data");
+    ASSERT_TRUE(ios.inputs->config_layout.ndim == 4);
+    ASSERT_TRUE(ios.inputs->config_layout.shapes[1] == 3);
+    ASSERT_TRUE(ios.inputs->config_layout.shapes[2] == 224);
+    ASSERT_TRUE(ios.inputs->config_layout.shapes[3] == 224);
+    ASSERT_TRUE(
+            std::string(ios.outputs->name) ==
+            "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
+    ASSERT_TRUE(ios.outputs->config_layout.ndim == 2);
+    ASSERT_TRUE(ios.outputs->config_layout.shapes[0] == 1);
+    ASSERT_TRUE(ios.outputs->config_layout.shapes[1] == 1000);
+
+    ASSERT_TRUE(std::string(ios_mem.inputs->name) == "data");
+    ASSERT_TRUE(ios_mem.inputs->config_layout.ndim == 4);
+    ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[1] == 3);
+    ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[2] == 224);
+    ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[3] == 224);
+    ASSERT_TRUE(
+            std::string(ios_mem.outputs->name) ==
+            "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
+    ASSERT_TRUE(ios_mem.outputs->config_layout.ndim == 2);
+    ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[0] == 1);
+    ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000);
+}
+
 #if LITE_BUILD_WITH_RKNPU
 
 static int GetTop(