Commit e70c07a2, authored by Megvii Engine Team

feat(lite): add global layout transform c/c++ interface for lite

GitOrigin-RevId: 36a4b26b42079611f38c9f817165a26269749ec3
Parent 86ee4638
@@ -97,7 +97,7 @@ struct LITE_API Options {
bool no_profiling_on_shape_change = false;
uint8_t jit_level = 0;
uint8_t comp_node_seq_record_level = 0;
uint8_t graph_opt_level = 0;
uint16_t async_exec_level = 1;
//! layout transform options
@@ -366,6 +366,14 @@ public:
static void shared_weight_with_network(
std::shared_ptr<Network> dst_network,
const std::shared_ptr<Network> src_network);
//! set global layout transform optimization for network
static void enable_global_layout_transform(std::shared_ptr<Network> network);
//! dump network after global layout transform optimization
static void dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path);
};
} // namespace lite
...
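For reference, a minimal sketch of how the new C++ interface is meant to be called, pieced together from the Runtime declarations above and the GlobalLayoutTransform test added further down in this commit; the model path, input name, and input handling are placeholders:

#include <memory>
#include "lite/network.h"

void run_with_global_layout_transform() {
    lite::Config config;
    lite::NetworkIO io;
    std::shared_ptr<lite::Network> network = std::make_shared<lite::Network>(config, io);
    //! must be called before load_model, otherwise the Runtime wrapper asserts
    lite::Runtime::enable_global_layout_transform(network);
    network->load_model("./shufflenet.mge");
    //! optionally persist the layout-optimized graph, only valid after load_model
    lite::Runtime::dump_layout_transform_model(network, "./shufflenet_after_trans.mge");
    std::shared_ptr<lite::Tensor> input = network->get_io_tensor("data");
    //! ... fill the input tensor, e.g. via input->reset(ptr, layout) as the test does ...
    network->forward();
    network->wait();
}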
@@ -572,6 +572,22 @@ LITE_API int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out
LITE_API int LITE_get_static_memory_alloc_info(
LiteNetwork network, const char* log_dir);
/**
* \brief enable the global layout transform optimization
* \param[in] network The network to be optimized; this must be called before
* the model is loaded
* \return int if the return is not zero, an error happened; the error message
* can be retrieved by LITE_get_last_error
*/
LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
/**
* \brief dump the model after the global layout transform optimization
* \param[in] network The network whose optimized model will be dumped
* \param[in] dump_file_path The file path to which the optimized model is dumped
* \return int if the return is not zero, an error happened; the error message
* can be retrieved by LITE_get_last_error
*/
LITE_API int LITE_dump_layout_transform_model(
LiteNetwork network, const char* dump_file_path);
#ifdef __cplusplus
}
#endif
...
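The C entry points mirror the C++ Runtime methods and follow the same ordering rule (enable before load, dump after load). Below is a minimal sketch of the call sequence, modeled on the capi test added at the end of this commit; the surrounding helpers (LITE_make_network with default_config()/default_network_io(), LITE_load_model_from_path, LITE_forward, LITE_wait, LITE_destroy_network) are assumed from the existing lite-c API, and the file paths are placeholders:

#include "lite-c/global_c.h"
#include "lite-c/network_c.h"

int run_with_global_layout_transform_c(void) {
    LiteNetwork c_network;
    if (LITE_make_network(&c_network, *default_config(), *default_network_io()))
        return -1;
    // must be enabled before the model is loaded
    if (LITE_enable_global_layout_transform(c_network))
        return -1;
    if (LITE_load_model_from_path(c_network, "./shufflenet.mge"))
        return -1;
    // dumping the optimized model is only valid after the model is loaded
    if (LITE_dump_layout_transform_model(c_network, "./shufflenet_after_trans.mge"))
        return -1;
    // ... set the input tensors here, then run ...
    if (LITE_forward(c_network) || LITE_wait(c_network))
        return -1;
    return LITE_destroy_network(c_network);
}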
@@ -648,4 +648,21 @@ int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir)
LITE_CAPI_END();
}
int LITE_enable_global_layout_transform(LiteNetwork network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network passed to LITE api is null");
std::shared_ptr<lite::Network> network_shared{
static_cast<lite::Network*>(network), [](void*) {}};
lite::Runtime::enable_global_layout_transform(network_shared);
LITE_CAPI_END();
}
int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network passed to LITE api is null");
std::shared_ptr<lite::Network> network_shared{
static_cast<lite::Network*>(network), [](void*) {}};
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
LITE_CAPI_END();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -121,6 +121,8 @@ inline void call_func<NetworkImplDft, void>(
CALL_FUNC(use_tensorrt);
} else if (func_name == "set_cpu_inplace_mode") {
CALL_FUNC(set_cpu_inplace_mode);
} else if (func_name == "enable_global_layout_transform") {
CALL_FUNC(enable_global_layout_transform);
} else {
THROW_FUNC_ERROR(func_name);
}
@@ -186,6 +188,8 @@ inline void call_func<NetworkImplDft, void>(
return CALL_FUNC(enable_io_txt_dump, file_name);
} else if (func_name == "enable_io_bin_dump") {
return CALL_FUNC(enable_io_bin_dump, file_name);
} else if (func_name == "dump_layout_transform_model") {
return CALL_FUNC(dump_layout_transform_model, file_name);
}
THROW_FUNC_ERROR(func_name);
}
...
@@ -22,7 +22,6 @@
#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/graph/cg.h" #include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
...@@ -364,19 +363,26 @@ void NetworkImplDft::adapt_option_valid() { ...@@ -364,19 +363,26 @@ void NetworkImplDft::adapt_option_valid() {
} }
} }
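//! rewrite the loaded graph in place with mgb::gopt::layout_transform when the
//! user has enabled the global layout transform, so that the subsequent output
//! spec construction and compilation run on the layout-optimized graph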
void NetworkImplDft::global_layout_transform() {
if (m_set_layout_transform) {
m_load_result.output_var_list = mgb::gopt::layout_transform(
m_load_result.output_var_list, m_layout_transform_target);
}
}
void NetworkImplDft::load_model(
std::shared_ptr<void> model_mem, size_t size,
std::unordered_map<std::string, LiteAny> separate_config_map) {
if (!m_loader) {
m_input_file =
mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
*m_input_file);
if (!m_format.valid()) {
LITE_THROW("invalid model format");
}
m_loader = mgb::serialization::GraphLoader::make(
std::move(m_input_file), m_format.val());
}
//! apply the user configuration to the mge model
@@ -400,7 +406,9 @@ void NetworkImplDft::load_model(
use_tensorrt();
}
m_load_result = m_loader->load(m_load_config, false);
global_layout_transform();
adapt_option_valid();
@@ -847,9 +855,6 @@ const char* NetworkImplDft::get_input_name(size_t index) const {
//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
#if MGB_OPENCL
mgb::CompNode::enable_opencl_profile(true);
#endif
m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
m_profiler_output_file = profile_json_file;
#else
@@ -889,5 +894,40 @@ void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const
LITE_MARK_USED_VAR(log_dir);
}
void NetworkImplDft::enable_global_layout_transform() {
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
switch (m_user_config->device_type) {
case LiteDeviceType::LITE_CPU:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
break;
case LiteDeviceType::LITE_CUDA:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
break;
default:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
LITE_WARN(
"lite compnode type (enum value: %d) has no specialized layout "
"transform target, fall back to UNSPEC",
(int)(m_user_config->device_type));
}
m_set_layout_transform = true;
}
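//! serialize the transformed graph to the given path; the dump reuses m_format
//! recorded in load_model, so this must be called after the model is loaded and
//! after enable_global_layout_transform has been set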
void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
if (m_set_layout_transform) {
auto out_file = mgb::serialization::OutputFile::make_fs(
optimized_model_path.c_str(), 'w');
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
DumpConfig config{1, false, false};
auto dumper = mgb::serialization::GraphDumper::make(
std::move(out_file), m_format.val());
dumper->dump(m_load_result.output_var_list, config);
} else {
LITE_THROW(
ssprintf("dump_layout_transform_model should be called after "
"enable_global_layout_transform"));
}
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -19,6 +19,9 @@
#include "network_impl_base.h"
#include "tensor_impl.h"
#include <memory>
#include <unordered_map>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/bases.h" #include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h" #include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h" #include "megbrain/plugin/profiler.h"
...@@ -28,9 +31,6 @@ ...@@ -28,9 +31,6 @@
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"
#include <memory>
#include <unordered_map>
namespace lite { namespace lite {
/*!
@@ -170,11 +170,20 @@ public:
void get_static_memory_alloc_info(
const std::string& log_dir = "logs/test") const override;
//! set global layout transform optimization for network
void enable_global_layout_transform();
//! dump network after global layout transform optimization
void dump_layout_transform_model(std::string optimized_model_path);
private:
//! construct the outputspec according to the m_network_io, and set the
//! call_back to the outputspec
void make_output_spec();
//! do the global layout transform for the given platform target
void global_layout_transform();
//! modify the execution policy
void modify_exection_policy();
@@ -223,6 +232,7 @@ private:
int m_nr_device_type = 0;
size_t m_nr_threads = 1;
bool m_compute_configured_output_only = false;
bool m_set_layout_transform = false;
mgb::CompNode::Locator m_compnode_locator;
AsyncCallback m_async_callback = nullptr;
@@ -233,6 +243,9 @@
//! The model load related data
S m_execution_policy = static_cast<S>(0);
std::unique_ptr<mgb::serialization::InputFile> m_input_file;
mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
mgb::serialization::GraphLoadConfig m_load_config;
mgb::serialization::GraphLoader::LoadResult m_load_result;
mgb::ComputingGraph::OutputSpec m_output_spec;
...
@@ -505,4 +505,33 @@ void Runtime::shared_weight_with_network(
LITE_ERROR_HANDLER_END
}
void Runtime::enable_global_layout_transform(std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"enable_global_layout_transform should be used before model loaded.");
call_func<NetworkImplDft, void>("enable_global_layout_transform", network_impl);
return;
}
LITE_THROW("enable_global_layout_transform is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
NetworkHelper::loaded(network),
"dump_layout_transform_model should be used after model loaded.");
call_func<NetworkImplDft, void>(
"dump_layout_transform_model", network_impl, optimized_model_path);
return;
}
LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -909,6 +909,30 @@ TEST(TestNetWork, LoadPackedModel) {
network->wait();
}
TEST(TestNetWork, GlobalLayoutTransform) {
// set_log_level(LiteLogLevel::DEBUG);
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
std::string input_name = "data";
std::string dump_model_name = "./shufflenet_after_trans.mge";
NetworkIO IO;
Config config;
std::shared_ptr<Network> network = std::make_shared<Network>(config, IO);
Runtime::enable_global_layout_transform(network);
network->load_model(model_path);
std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
auto src_ptr = tensor->get_memory_ptr();
auto src_layout = tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
Runtime::dump_layout_transform_model(network, dump_model_name);
network->forward();
network->wait();
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r"));
}
TEST(TestNetWork, GetDeviceType) {
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
...
@@ -889,6 +889,21 @@ TEST(TestCapiNetWork, ProfileIOdump) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, GlobalLayoutTransform) {
ForwardMgb;
MakeNetwork;
LITE_CAPI_CHECK(LITE_enable_global_layout_transform(c_network));
LoadNetwork;
LITE_CAPI_CHECK(LITE_dump_layout_transform_model(
c_network, "./shufflenet_after_trans.mge"));
SetInput;
ForwardNetwork;
ASSERT_TRUE(fopen("./shufflenet_after_trans.mge", "r"));
GetOutput;
CompareResult;
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, GetDeviceType) {
lite::Config config;
auto lite_tensor = lite::get_input_data("./input_data.npy");
...