提交 8bdcf6b5 编写于 作者: M Megvii Engine Team

feat(lite): add get static mem info function in lite c++

GitOrigin-RevId: 8c9e42a74409381d17601bdb0bac7dd08483193a
上级 b84d2893
......@@ -282,6 +282,9 @@ public:
//! get device type
LiteDeviceType get_device_type() const;
//! get static peak memory info showed by Graph visualization
void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;
public:
friend class NetworkHelper;
......
......@@ -778,6 +778,16 @@ void inline NetworkImplDft::output_plugin_result() const {
}
#endif
}
void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
m_execute_func->get_static_memory_alloc_info(log_dir);
return;
#endif
#endif
LITE_MARK_USED_VAR(log_dir);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -163,6 +163,10 @@ public:
//! directory, in binary format
void enable_io_bin_dump(std::string io_bin_out_dir);
//! get static peak memory info showed by Graph visualization
void get_static_memory_alloc_info(
const std::string& log_dir = "logs/test") const override;
private:
//! construct the outputspec according to the m_network_io, and set the
//! call_back to the outputspec
......
......@@ -283,6 +283,20 @@ LiteDeviceType Network::get_device_type() const {
LITE_ERROR_HANDLER_END
}
void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
LITE_ERROR_HANDLER_BEGIN
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
LITE_ASSERT(m_loaded, "get_all_output_name should be used after model loaded.");
m_impl->get_static_memory_alloc_info(log_dir);
return;
#endif
#endif
LITE_MARK_USED_VAR(log_dir);
LITE_THROW("Doesn't support get_static_memory_alloc_info().Please check macro.");
LITE_ERROR_HANDLER_END
}
/*********************** MGE special network function ***************/
void Runtime::set_cpu_threads_number(
......
......@@ -125,6 +125,14 @@ public:
//! enable profile the network, a file will be generated
virtual void enable_profile_performance(std::string profile_file_path) = 0;
//! get static peak memory info showed by Graph visualization
virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
LITE_MARK_USED_VAR(log_dir);
LITE_THROW(
"This nerworkimpl doesn't support get_static_memory_alloc_info() "
"function.");
}
};
/******************************** friend class *****************************/
......
......@@ -646,6 +646,35 @@ TEST(TestNetWork, GetModelExtraInfo) {
printf("extra_info %s \n", extra_info.c_str());
}
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
TEST(TestNetWork, GetMemoryInfo) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
auto result_mgb = mgb_lar(model_path, config, "data", lite_tensor);
std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_threads_number(network, 2);
network->load_model(model_path);
network->get_static_memory_alloc_info();
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
auto src_ptr = lite_tensor->get_memory_ptr();
auto src_layout = lite_tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
network->forward();
network->wait();
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
compare_lite_tensor<float>(output_tensor, result_mgb);
}
#endif
#endif
#if LITE_WITH_CUDA
TEST(TestNetWork, BasicDevice) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册