Commit 49d92d9c authored by Megvii Engine Team

feat(lite): add layout transform interface for lite model

GitOrigin-RevId: 57c7678419dabe17de562d45f56dafcc2bd4eef0
Parent: 2a900a69
@@ -19,6 +19,9 @@ ModelLite::ModelLite(const std::string& path) : model_path(path) {
 };
 void ModelLite::load_model() {
     m_network = std::make_shared<lite::Network>(config, IO);
+    if (enable_layout_transform) {
+        lite::Runtime::enable_global_layout_transform(m_network);
+    }
     if (share_model_mem) {
         //! WARNNING:maybe not right to share param memmory for this
         LITE_WARN("enable share model memory");
...
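The new branch in load_model() only takes effect when the caller has requested it through the set_layout_transform() interface added in the header hunk below. A minimal caller-side sketch of the intended order, assuming a lar::ModelLite built from a hypothetical model path:

    // Sketch only: the path and the shared_ptr construction are illustrative.
    auto model = std::make_shared<lar::ModelLite>("model.lite");
    model->set_layout_transform(true);  // records enable_layout_transform
    model->load_model();                // enable_global_layout_transform() runs before the load

The ordering matters: lite::Runtime::enable_global_layout_transform() acts on the freshly constructed network before the model file is parsed, so the request has to be recorded before load_model() is called.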
@@ -39,6 +39,10 @@ public:
     //! wait the end of asynchronous function execution
     void wait() override;
+
+    //! enable global layout transform
+    void set_layout_transform(bool state) { enable_layout_transform = state; }
+
     //! get the network of lite model
     std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }

@@ -59,6 +63,7 @@ public:
 private:
     bool share_model_mem;
+    bool enable_layout_transform;
     std::string model_path;
     DataParser parser;
...
@@ -16,9 +16,30 @@ namespace lar {
 template <>
 void GoptLayoutOption::config_model_internel<ModelLite>(
-        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> /* model */) {
+        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
     if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-        LITE_THROW("lite model don't support global graph optimization");
+        if (m_layout_transform) {
+            if (m_layout_transform_target ==
+                mgb::gopt::GraphTuningOptions::Target::CPU) {
+                model->get_config().device_type = LiteDeviceType::LITE_CPU;
+            }
+#if LITE_WITH_CUDA
+            else if (
+                    m_layout_transform_target ==
+                    mgb::gopt::GraphTuningOptions::Target::CUDA) {
+                model->get_config().device_type = LiteDeviceType::LITE_CUDA;
+            }
+#endif
+            model->set_layout_transform(true);
+        }
+    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
+        if (m_layout_transform) {
+            auto&& network = model->get_lite_network();
+            if (!m_layout_transform_dump_file.empty()) {
+                lite::Runtime::dump_layout_transform_model(
+                        network, m_layout_transform_dump_file);
+            }
+        }
     }
 }
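Taken together, the ModelLite specialization now drives the lite runtime in two stages: BEFORE_MODEL_LOAD chooses the device type and flags the transform, and GLOBAL_OPTIMIZATION optionally dumps the transformed model. In terms of the plain lite API the sequence is roughly the sketch below (CPU target; the file paths are hypothetical and the Network construction assumes lite's config-taking constructor):

    lite::Config config;
    config.device_type = LiteDeviceType::LITE_CPU;            // BEFORE_MODEL_LOAD: pick the target device
    auto network = std::make_shared<lite::Network>(config);
    lite::Runtime::enable_global_layout_transform(network);   // must precede the load
    network->load_model("model.lite");
    lite::Runtime::dump_layout_transform_model(                // GLOBAL_OPTIMIZATION: optional dump
            network, "optimized.lite");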
@@ -26,14 +47,14 @@ template <>
 void GoptLayoutOption::config_model_internel<ModelMdl>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
     if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
-        if (layout_transform) {
+        if (m_layout_transform) {
             auto&& load_result = model->get_mdl_load_result();
             load_result.output_var_list = mgb::gopt::layout_transform(
-                    load_result.output_var_list, layout_transform_target);
+                    load_result.output_var_list, m_layout_transform_target);
-            if (!layout_transform_dump_file.empty()) {
+            if (!m_layout_transform_dump_file.empty()) {
                 auto out_file = mgb::serialization::OutputFile::make_fs(
-                        layout_transform_dump_file.c_str(), 'w');
+                        m_layout_transform_dump_file.c_str(), 'w');
                 auto testcase_num = model->get_testcase_num();
                 if (testcase_num) {

@@ -56,7 +77,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
             mgb::serialization::GraphDumper::DumpConfig config{1, false, false};
             for (size_t i = 0; i < testcase_num; ++i) {
                 auto casefile = mgb::serialization::OutputFile::make_fs(
-                        layout_transform_dump_file.c_str(), 'a');
+                        m_layout_transform_dump_file.c_str(), 'a');
                 auto casedumper = model->get_dumper(std::move(casefile));
                 casedumper->dump(testcase.output_var_list, config);
                 if (i != testcase_num - 1) {
@@ -80,29 +101,37 @@ using namespace lar;
 GoptLayoutOption::GoptLayoutOption() {
     m_option_name = "gopt_layout";
-    if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
-        FLAGS_layout_transform != "opencl") {
-        layout_transform = false;
-        layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
+    if (FLAGS_layout_transform != "cpu"
+#if LITE_WITH_CUDA
+        && FLAGS_layout_transform != "cuda"
+#endif
+    ) {
+        m_layout_transform = false;
+        m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
     } else {
-        layout_transform = true;
-        if (FLAGS_layout_transform == "cuda") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
-        } else if (FLAGS_layout_transform == "cpu") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
-        } else if (FLAGS_layout_transform == "opencl") {
-            layout_transform_target = mgb::gopt::GraphTuningOptions::Target::OPENCL;
+        m_layout_transform = true;
+
+        if (FLAGS_layout_transform == "cpu") {
+            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
         }
+#if LITE_WITH_CUDA
+        else if (FLAGS_layout_transform == "cuda") {
+            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
+        }
+#endif
     }
-    layout_transform_dump_file = FLAGS_layout_transform_dump;
+    m_layout_transform_dump_file = FLAGS_layout_transform_dump;
 }
 
 bool GoptLayoutOption::is_valid() {
     bool ret = false;
     if (!FLAGS_layout_transform.empty()) {
-        if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
-            FLAGS_layout_transform != "opencl") {
+        if (FLAGS_layout_transform != "cpu"
+#if LITE_WITH_CUDA
+            && FLAGS_layout_transform != "cuda"
+#endif
+        ) {
             mgb_assert(
                     false,
                     "unsupported target(got:%s) for global layout "
...
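With the constructor and is_valid() reworked this way, the accepted values for the layout-transform flag are "cpu" and, only when built with LITE_WITH_CUDA, "cuda"; the previously listed "opencl" target no longer appears in either check. A typical invocation would look something like --layout-transform=cpu --layout-transform-dump=optimized.mge, where the hyphenated spellings assume the usual gflags mapping of FLAGS_layout_transform and FLAGS_layout_transform_dump and the dump file name is only illustrative.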
@@ -37,9 +37,9 @@ private:
     //! config template for different model
     template <typename ModelImpl>
     void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
-    bool layout_transform;
+    bool m_layout_transform;
     std::string m_option_name;
-    std::string layout_transform_dump_file;
-    mgb::gopt::GraphTuningOptions::Target layout_transform_target;
+    std::string m_layout_transform_dump_file;
+    mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
 };
 }  // namespace lar