From 8bdcf6b5a694ef42b26bd91fa4ecad6ad622e00d Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 15 Oct 2021 16:32:42 +0800
Subject: [PATCH] feat(lite): add get static mem info function in lite c++

GitOrigin-RevId: 8c9e42a74409381d17601bdb0bac7dd08483193a
---
 lite/include/lite/network.h   |  3 +++
 lite/src/mge/network_impl.cpp | 10 ++++++++++
 lite/src/mge/network_impl.h   |  4 ++++
 lite/src/network.cpp          | 14 ++++++++++++++
 lite/src/network_impl_base.h  |  8 ++++++++
 lite/test/test_network.cpp    | 29 +++++++++++++++++++++++++++++
 6 files changed, 68 insertions(+)

diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h
index bfde5dea9..1bf26b824 100644
--- a/lite/include/lite/network.h
+++ b/lite/include/lite/network.h
@@ -282,6 +282,9 @@ public:
     //! get device type
     LiteDeviceType get_device_type() const;

+    //! get static peak memory info shown by graph visualization
+    void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;
+
 public:
     friend class NetworkHelper;

diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index a64c8bd0b..e75d598ab 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -778,6 +778,16 @@ void inline NetworkImplDft::output_plugin_result() const {
     }
 #endif
 }
+
+void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
+#ifndef __IN_TEE_ENV__
+#if MGB_ENABLE_JSON
+    m_execute_func->get_static_memory_alloc_info(log_dir);
+    return;
+#endif
 #endif
+    LITE_MARK_USED_VAR(log_dir);
+}
+#endif

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/lite/src/mge/network_impl.h b/lite/src/mge/network_impl.h
index d370253c3..04af5f9f2 100644
--- a/lite/src/mge/network_impl.h
+++ b/lite/src/mge/network_impl.h
@@ -163,6 +163,10 @@ public:
     //! directory, in binary format
     void enable_io_bin_dump(std::string io_bin_out_dir);

+    //! get static peak memory info shown by graph visualization
+    void get_static_memory_alloc_info(
+            const std::string& log_dir = "logs/test") const override;
+
 private:
     //! construct the outputspec according to the m_network_io, and set the
     //! call_back to the outputspec
diff --git a/lite/src/network.cpp b/lite/src/network.cpp
index 7c69b8645..745dbe5d2 100644
--- a/lite/src/network.cpp
+++ b/lite/src/network.cpp
@@ -283,6 +283,20 @@ LiteDeviceType Network::get_device_type() const {
     LITE_ERROR_HANDLER_END
 }

+void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
+    LITE_ERROR_HANDLER_BEGIN
+#ifndef __IN_TEE_ENV__
+#if MGB_ENABLE_JSON
+    LITE_ASSERT(m_loaded, "get_static_memory_alloc_info should be used after the model is loaded.");
+    m_impl->get_static_memory_alloc_info(log_dir);
+    return;
+#endif
+#endif
+    LITE_MARK_USED_VAR(log_dir);
+    LITE_THROW("Doesn't support get_static_memory_alloc_info(). Please check the build macros.");
+    LITE_ERROR_HANDLER_END
+}
+
 /*********************** MGE special network function ***************/

 void Runtime::set_cpu_threads_number(
diff --git a/lite/src/network_impl_base.h b/lite/src/network_impl_base.h
index 234beb66a..e6a05bb01 100644
--- a/lite/src/network_impl_base.h
+++ b/lite/src/network_impl_base.h
@@ -125,6 +125,14 @@ public:

     //! enable profile the network, a file will be generated
     virtual void enable_profile_performance(std::string profile_file_path) = 0;
+
+    //! get static peak memory info shown by graph visualization
+    virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
+        LITE_MARK_USED_VAR(log_dir);
+        LITE_THROW(
+                "This NetworkImpl doesn't support the get_static_memory_alloc_info() "
+                "function.");
+    }
 };

 /******************************** friend class *****************************/
diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index 9ddfa48bf..617f8055e 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -646,6 +646,35 @@ TEST(TestNetWork, GetModelExtraInfo) {
     printf("extra_info %s \n", extra_info.c_str());
 }

+#ifndef __IN_TEE_ENV__
+#if MGB_ENABLE_JSON
+TEST(TestNetWork, GetMemoryInfo) {
+    Config config;
+    auto lite_tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./shufflenet.mge";
+
+    auto result_mgb = mgb_lar(model_path, config, "data", lite_tensor);
+
+    std::shared_ptr<Network> network = std::make_shared<Network>(config);
+    Runtime::set_cpu_threads_number(network, 2);
+
+    network->load_model(model_path);
+    network->get_static_memory_alloc_info();
+    std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
+
+    auto src_ptr = lite_tensor->get_memory_ptr();
+    auto src_layout = lite_tensor->get_layout();
+    input_tensor->reset(src_ptr, src_layout);
+
+    network->forward();
+    network->wait();
+    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+
+    compare_lite_tensor<float>(output_tensor, result_mgb);
+}
+#endif
+#endif
+
 #if LITE_WITH_CUDA
 TEST(TestNetWork, BasicDevice) {
--
GitLab
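
A minimal usage sketch of the new API, not part of the patch: it assumes the `lite` namespace and the "lite/network.h" include path implied by the files touched above, and the model path and log directory are placeholders. The call must come after `load_model()` (the added LITE_ASSERT checks m_loaded), and it only writes the graph-visualization JSON when the build has MGB_ENABLE_JSON and is not a TEE build; otherwise it throws.

    #include <memory>
    #include "lite/network.h"

    int main() {
        lite::Config config;
        auto network = std::make_shared<lite::Network>(config);

        // Load a MegEngine Lite model first; the file name is a placeholder.
        network->load_model("./shufflenet.mge");

        // Dump static peak memory allocation info for graph visualization into
        // the given directory (defaults to "logs/test" per the new declaration).
        network->get_static_memory_alloc_info("logs/static_mem");
        return 0;
    }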