提交 e70c07a2 编写于 作者: M Megvii Engine Team

feat(lite): add global layout transform c/c++ interface for lite

GitOrigin-RevId: 36a4b26b42079611f38c9f817165a26269749ec3
上级 86ee4638
......@@ -97,7 +97,7 @@ struct LITE_API Options {
bool no_profiling_on_shape_change = false;
uint8_t jit_level = 0;
uint8_t comp_node_seq_record_level = 0;
uint8_t graph_opt_level = 2;
uint8_t graph_opt_level = 0;
uint16_t async_exec_level = 1;
//! layout transform options
......@@ -366,6 +366,14 @@ public:
static void shared_weight_with_network(
std::shared_ptr<Network> dst_network,
const std::shared_ptr<Network> src_network);
//! set global layout transform optimization for network
static void enable_global_layout_transform(std::shared_ptr<Network> network);
//! dump network after global layout transform optimization
static void dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path);
};
} // namespace lite
......
......@@ -572,6 +572,22 @@ LITE_API int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out
LITE_API int LITE_get_static_memory_alloc_info(
LiteNetwork network, const char* log_dir);
/**
* \brief enable the global layout transform optimization
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
/**
* \brief dump the model after the global layout transform optimization
* \param[in] dump_file_path The model file path need to dump
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API int LITE_dump_layout_transform_model(
LiteNetwork network, const char* dump_file_path);
#ifdef __cplusplus
}
#endif
......
......@@ -648,4 +648,21 @@ int LITE_get_static_memory_alloc_info(LiteNetwork network, const char* log_dir)
LITE_CAPI_END();
}
int LITE_enable_global_layout_transform(LiteNetwork network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
std::shared_ptr<lite::Network> network_shared{
static_cast<lite::Network*>(network), [](void*) {}};
lite::Runtime::enable_global_layout_transform(network_shared);
LITE_CAPI_END();
}
int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_path) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
std::shared_ptr<lite::Network> network_shared{
static_cast<lite::Network*>(network), [](void*) {}};
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
LITE_CAPI_END();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -121,6 +121,8 @@ inline void call_func<NetworkImplDft, void>(
CALL_FUNC(use_tensorrt);
} else if (func_name == "set_cpu_inplace_mode") {
CALL_FUNC(set_cpu_inplace_mode);
} else if (func_name == "enable_global_layout_transform") {
CALL_FUNC(enable_global_layout_transform);
} else {
THROW_FUNC_ERROR(func_name);
}
......@@ -186,6 +188,8 @@ inline void call_func<NetworkImplDft, void>(
return CALL_FUNC(enable_io_txt_dump, file_name);
} else if (func_name == "enable_io_bin_dump") {
return CALL_FUNC(enable_io_bin_dump, file_name);
} else if (func_name == "dump_layout_transform_model") {
return CALL_FUNC(dump_layout_transform_model, file_name);
}
THROW_FUNC_ERROR(func_name);
}
......
......@@ -22,7 +22,6 @@
#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
......@@ -364,19 +363,26 @@ void NetworkImplDft::adapt_option_valid() {
}
}
void NetworkImplDft::global_layout_transform() {
if (m_set_layout_transform) {
m_load_result.output_var_list = mgb::gopt::layout_transform(
m_load_result.output_var_list, m_layout_transform_target);
}
}
void NetworkImplDft::load_model(
std::shared_ptr<void> model_mem, size_t size,
std::unordered_map<std::string, LiteAny> separate_config_map) {
if (!m_loader) {
m_input_file =
mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
auto format = mgb::serialization::GraphLoader::identify_graph_dump_format(
m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
*m_input_file);
if (!format.valid()) {
if (!m_format.valid()) {
LITE_THROW("invalid model format");
}
m_loader = mgb::serialization::GraphLoader::make(
std::move(m_input_file), format.val());
std::move(m_input_file), m_format.val());
}
//! applay the user configration to mge model
......@@ -400,7 +406,9 @@ void NetworkImplDft::load_model(
use_tensorrt();
}
m_load_result = m_loader->load(m_load_config, true);
m_load_result = m_loader->load(m_load_config, false);
global_layout_transform();
adapt_option_valid();
......@@ -847,9 +855,6 @@ const char* NetworkImplDft::get_input_name(size_t index) const {
//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
#if MGB_OPENCL
mgb::CompNode::enable_opencl_profile(true);
#endif
m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
m_profiler_output_file = profile_json_file;
#else
......@@ -889,5 +894,40 @@ void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) co
LITE_MARK_USED_VAR(log_dir);
}
void NetworkImplDft::enable_global_layout_transform() {
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
switch (m_user_config->device_type) {
case LiteDeviceType::LITE_CPU:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
break;
case LiteDeviceType::LITE_CUDA:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
break;
default:
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
LITE_WARN(
"lite compnode type: enum value: %d. is unspecial for layout "
"transform",
(int)(m_user_config->device_type));
}
m_set_layout_transform = true;
}
void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
if (m_set_layout_transform) {
auto out_file = mgb::serialization::OutputFile::make_fs(
optimized_model_path.c_str(), 'w');
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
DumpConfig config{1, false, false};
auto dumper = mgb::serialization::GraphDumper::make(
std::move(out_file), m_format.val());
dumper->dump(m_load_result.output_var_list, config);
} else {
LITE_THROW(
ssprintf("dump layout transform model should call "
"enable_global_layout_transform before"));
}
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -19,6 +19,9 @@
#include "network_impl_base.h"
#include "tensor_impl.h"
#include <memory>
#include <unordered_map>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h"
......@@ -28,9 +31,6 @@
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/thin/hash_table.h"
#include <memory>
#include <unordered_map>
namespace lite {
/*!
......@@ -170,11 +170,20 @@ public:
void get_static_memory_alloc_info(
const std::string& log_dir = "logs/test") const override;
//! set global layout transform optimization for network
void enable_global_layout_transform();
//! dump network after global layout transform optimization
void dump_layout_transform_model(std::string optimized_model_path);
private:
//! construct the outputspec according to the m_network_io, and set the
//! call_back to the outputspec
void make_output_spec();
//! do the global layout transform for the given platform target
void global_layout_transform();
//! modify the execution policy
void modify_exection_policy();
......@@ -223,6 +232,7 @@ private:
int m_nr_device_type = 0;
size_t m_nr_threads = 1;
bool m_compute_configured_output_only = false;
bool m_set_layout_transform = false;
mgb::CompNode::Locator m_compnode_locator;
AsyncCallback m_async_callback = nullptr;
......@@ -233,6 +243,9 @@ private:
//! The model load related data
S m_execution_policy = static_cast<S>(0);
std::unique_ptr<mgb::serialization::InputFile> m_input_file;
mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
mgb::serialization::GraphLoadConfig m_load_config;
mgb::serialization::GraphLoader::LoadResult m_load_result;
mgb::ComputingGraph::OutputSpec m_output_spec;
......
......@@ -505,4 +505,33 @@ void Runtime::shared_weight_with_network(
LITE_ERROR_HANDLER_END
}
void Runtime::enable_global_layout_transform(std::shared_ptr<Network> network) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
!NetworkHelper::loaded(network),
"enable_global_layout_transform should be used before model loaded.");
call_func<NetworkImplDft, void>("enable_global_layout_transform", network_impl);
return;
}
LITE_THROW("enable_global_layout_transform is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
void Runtime::dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path) {
LITE_ERROR_HANDLER_BEGIN
auto network_impl = NetworkHelper::implement(network);
if (network_impl->get_backend_type() == LiteBackend::LITE_DEFAULT) {
LITE_ASSERT(
NetworkHelper::loaded(network),
"dump_layout_transform_model should be used after model loaded.");
call_func<NetworkImplDft, void>(
"dump_layout_transform_model", network_impl, optimized_model_path);
return;
}
LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -909,6 +909,30 @@ TEST(TestNetWork, LoadPackedModel) {
network->wait();
}
TEST(TestNetWork, GlabalLayoutTransform) {
// set_log_level(LiteLogLevel::DEBUG);
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
std::string input_name = "data";
std::string dump_model_name = "./shufflenet_after_trans.mge";
NetworkIO IO;
Config config;
std::shared_ptr<Network> network = std::make_shared<Network>(config, IO);
Runtime::enable_global_layout_transform(network);
network->load_model(model_path);
std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
auto src_ptr = tensor->get_memory_ptr();
auto src_layout = tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
Runtime::dump_layout_transform_model(network, dump_model_name);
network->forward();
network->wait();
ASSERT_TRUE(fopen(dump_model_name.c_str(), "r"));
}
TEST(TestNetWork, GetDeviceType) {
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
......
......@@ -889,6 +889,21 @@ TEST(TestCapiNetWork, ProfileIOdump) {
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, GlabalLayoutTransform) {
ForwardMgb;
MakeNetwork;
LITE_CAPI_CHECK(LITE_enable_global_layout_transform(c_network));
LoadNetwork;
LITE_CAPI_CHECK(LITE_dump_layout_transform_model(
c_network, "./shufflenet_after_trans.mge"));
SetInput;
ForwardNetwork;
ASSERT_TRUE(fopen("./shufflenet_after_trans.mge", "r"));
GetOutput;
CompareResult;
LITE_CAPI_CHECK(LITE_destroy_network(c_network));
}
TEST(TestCapiNetWork, GetDeviceType) {
lite::Config config;
auto lite_tensor = lite::get_input_data("./input_data.npy");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册