提交 49d92d9c 编写于 作者: M Megvii Engine Team

feat(lite): feat layout transform interface for lite model

GitOrigin-RevId: 57c7678419dabe17de562d45f56dafcc2bd4eef0
上级 2a900a69
......@@ -19,6 +19,9 @@ ModelLite::ModelLite(const std::string& path) : model_path(path) {
};
void ModelLite::load_model() {
m_network = std::make_shared<lite::Network>(config, IO);
if (enable_layout_transform) {
lite::Runtime::enable_global_layout_transform(m_network);
}
if (share_model_mem) {
//! WARNNING:maybe not right to share param memmory for this
LITE_WARN("enable share model memory");
......
......@@ -39,6 +39,10 @@ public:
//! wait the end of asynchronous function execution
void wait() override;
//! enable global layout transform
void set_layout_transform(bool state) { enable_layout_transform = state; }
//! get the network of lite model
std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }
......@@ -59,6 +63,7 @@ public:
private:
bool share_model_mem;
bool enable_layout_transform;
std::string model_path;
DataParser parser;
......
......@@ -16,9 +16,30 @@ namespace lar {
template <>
void GoptLayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> /* model */) {
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
LITE_THROW("lite model don't support global graph optimization");
if (m_layout_transform) {
if (m_layout_transform_target ==
mgb::gopt::GraphTuningOptions::Target::CPU) {
model->get_config().device_type = LiteDeviceType::LITE_CPU;
}
#if LITE_WITH_CUDA
else if (
m_layout_transform_target ==
mgb::gopt::GraphTuningOptions::Target::CUDA) {
model->get_config().device_type = LiteDeviceType::LITE_CUDA;
}
#endif
model->set_layout_transform(true);
}
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
auto&& network = model->get_lite_network();
if (!m_layout_transform_dump_file.empty()) {
lite::Runtime::dump_layout_transform_model(
network, m_layout_transform_dump_file);
}
}
}
}
......@@ -26,14 +47,14 @@ template <>
void GoptLayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (layout_transform) {
if (m_layout_transform) {
auto&& load_result = model->get_mdl_load_result();
load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, layout_transform_target);
load_result.output_var_list, m_layout_transform_target);
if (!layout_transform_dump_file.empty()) {
if (!m_layout_transform_dump_file.empty()) {
auto out_file = mgb::serialization::OutputFile::make_fs(
layout_transform_dump_file.c_str(), 'w');
m_layout_transform_dump_file.c_str(), 'w');
auto testcase_num = model->get_testcase_num();
if (testcase_num) {
......@@ -56,7 +77,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
mgb::serialization::GraphDumper::DumpConfig config{1, false, false};
for (size_t i = 0; i < testcase_num; ++i) {
auto casefile = mgb::serialization::OutputFile::make_fs(
layout_transform_dump_file.c_str(), 'a');
m_layout_transform_dump_file.c_str(), 'a');
auto casedumper = model->get_dumper(std::move(casefile));
casedumper->dump(testcase.output_var_list, config);
if (i != testcase_num - 1) {
......@@ -80,29 +101,37 @@ using namespace lar;
GoptLayoutOption::GoptLayoutOption() {
m_option_name = "gopt_layout";
if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
FLAGS_layout_transform != "opencl") {
layout_transform = false;
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
if (FLAGS_layout_transform != "cpu"
#if LITE_WITH_CUDA
&& FLAGS_layout_transform != "cuda"
#endif
) {
m_layout_transform = false;
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
} else {
layout_transform = true;
if (FLAGS_layout_transform == "cuda") {
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
} else if (FLAGS_layout_transform == "cpu") {
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
} else if (FLAGS_layout_transform == "opencl") {
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::OPENCL;
m_layout_transform = true;
if (FLAGS_layout_transform == "cpu") {
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
}
#if LITE_WITH_CUDA
else if (FLAGS_layout_transform == "cuda") {
m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
}
#endif
}
layout_transform_dump_file = FLAGS_layout_transform_dump;
m_layout_transform_dump_file = FLAGS_layout_transform_dump;
}
bool GoptLayoutOption::is_valid() {
bool ret = false;
if (!FLAGS_layout_transform.empty()) {
if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
FLAGS_layout_transform != "opencl") {
if (FLAGS_layout_transform != "cpu"
#if LITE_WITH_CUDA
&& FLAGS_layout_transform != "cuda"
#endif
) {
mgb_assert(
false,
"unsupported target(got:%s) for global layout "
......
......@@ -37,9 +37,9 @@ private:
//! config template for different model
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
bool layout_transform;
bool m_layout_transform;
std::string m_option_name;
std::string layout_transform_dump_file;
mgb::gopt::GraphTuningOptions::Target layout_transform_target;
std::string m_layout_transform_dump_file;
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
};
} // namespace lar
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册