提交 cc218550 编写于 作者: M Megvii Engine Team

feat(lite): load_and_run support optimize for inference

GitOrigin-RevId: d9abb8de9eb81c357c3f5110b2467582d80b6c2c
上级 9bbe5500
......@@ -36,16 +36,19 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) {
auto model_type = get_model_type(model_path);
if (ModelType::LITE_MODEL == model_type) {
if (FLAGS_lite) {
mgb_log("run model force lite mode\n");
return std::make_shared<ModelLite>(model_path);
} else if (ModelType::MEGDL_MODEL == model_type) {
if (FLAGS_lite)
return std::make_shared<ModelLite>(model_path);
else
} else if (FLAGS_mdl) {
mgb_log("run model force mdl mode\n");
return std::make_shared<ModelMdl>(model_path);
} else if (ModelType::LITE_MODEL == model_type) {
return std::make_shared<ModelLite>(model_path);
} else {
return nullptr;
mgb_assert(ModelType::MEGDL_MODEL == model_type);
return std::make_shared<ModelMdl>(model_path);
}
}
DEFINE_bool(lite, false, "use megengine lite interface to run model");
DEFINE_bool(mdl, false, "use megengine mdl interface to run model");
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -4,6 +4,7 @@
#include "helpers/common.h"
#include "megbrain/utils/json.h"
DECLARE_bool(lite);
DECLARE_bool(mdl);
namespace lar {
/*!
......
......@@ -42,6 +42,8 @@ public:
return m_load_result;
}
void update_mdl_load_result(const mgb::SymbolVarArray& output_var_array);
//! get load config for megDL model
mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; }
......
......@@ -31,6 +31,13 @@ void FastRunOption::config_model_internel<ModelLite>(
LITE_LOG("enable fast-run strategy for algo profile");
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
} else if ((!m_fast_run_cache.empty() &&
!access(m_fast_run_cache.c_str(), F_OK))) {
LITE_LOG(
"detect fast-run cache usable set LITE_ALGO_PROFILE for algo "
"profile");
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
} else {
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
}
......
......@@ -299,6 +299,75 @@ void FuseConvBiasElemwiseAddOption::config_model(
CONFIG_MODEL_FUN;
}
///////////////////////// optimize for inference options ///////////////
bool OptimizeForInferenceOption::m_valid;
namespace lar {
template <>
void OptimizeForInferenceOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
LITE_MARK_USED_VAR(model);
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto optimize_for_infer =
std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
->get_value();
if (optimize_for_infer) {
LITE_THROW(
"optimize for inference not supported in lite "
"model");
}
}
}
template <>
void OptimizeForInferenceOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
auto optimize_for_infer =
std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
->get_value();
if (optimize_for_infer) {
mgb_log("enable optimize for inference optimization");
auto&& load_result = model->get_mdl_load_result();
mgb::cg::GraphCommonOptimizeOptions opt =
model->get_mdl_load_result().graph->options().graph_opt;
auto inference_opt2 = mgb::gopt::OptimizeForInferenceOptions(opt);
auto output_var_list = mgb::gopt::optimize_for_inference(
load_result.output_var_list, inference_opt2);
model->get_mdl_load_result().update_output_var_list(output_var_list);
model->get_mdl_load_result().graph->options().graph_opt.clear();
}
}
}
} // namespace lar
void OptimizeForInferenceOption::update() {
m_option_name = "optimize_for_inference";
m_option = {{"optimize_for_inference", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
->set_value(FLAGS_optimize_for_inference);
}
bool OptimizeForInferenceOption::is_valid() {
bool ret = FLAGS_optimize_for_inference;
return ret || m_valid;
}
std::shared_ptr<OptionBase> OptimizeForInferenceOption::create_option() {
static std::shared_ptr<OptimizeForInferenceOption> option(
new OptimizeForInferenceOption);
if (OptimizeForInferenceOption::is_valid()) {
option->update();
return option;
} else {
return nullptr;
}
}
void OptimizeForInferenceOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
CONFIG_MODEL_FUN;
}
///////////////////////// graph retrict options /////////////////////////
bool GraphRecordOption::m_valid;
namespace lar {
......@@ -646,6 +715,9 @@ DEFINE_bool(
enable_fuse_conv_bias_with_z, false,
"fuse conv, bias (elemwise add), z(elemwise add) into one opr "
"(only support on GPU)");
DEFINE_bool(
optimize_for_inference, false,
"whether to optimize_for_inference, fuse bn and many base optimize");
///////////////////////// graph retrict options /////////////////////////
DEFINE_bool(
......@@ -699,6 +771,11 @@ REGIST_OPTION_CREATOR(
REGIST_OPTION_VALIDATER(
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid);
REGIST_OPTION_CREATOR(
optimize_for_inference, lar::OptimizeForInferenceOption::create_option);
REGIST_OPTION_VALIDATER(
optimize_for_inference, lar::OptimizeForInferenceOption::set_valid);
REGIST_OPTION_CREATOR(
fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option);
REGIST_OPTION_VALIDATER(
......
......@@ -5,6 +5,7 @@
#include "option_base.h"
DECLARE_bool(enable_fuse_preprocess);
DECLARE_bool(optimize_for_inference);
DECLARE_bool(fuse_grain);
DECLARE_bool(weight_preprocess);
DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
......@@ -216,6 +217,34 @@ private:
uint64_t workspace_limit;
};
///////////////////////// optimize for inference options /////////////////////////
class OptimizeForInferenceOption final : public OptionBase {
public:
static bool is_valid();
static std::shared_ptr<OptionBase> create_option();
void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
static void set_valid(bool val) { m_valid = val; }
std::string option_name() const override { return m_option_name; };
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
OptimizeForInferenceOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
std::string m_option_name;
static bool m_valid;
OptionValMap m_option;
};
///////////////////////// other options for optimization /////////////////
class JITOption final : public OptionBase {
public:
......
......@@ -366,19 +366,7 @@ void NetworkImplDft::layout_transform_optimization() {
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
auto output_var_array = mgb::gopt::layout_transform(
m_load_result.output_var_list, m_layout_transform_target);
// replace symvar in output_var_list
for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx];
m_load_result.output_var_list[idx] = output_var_array[idx];
}
// replace symvar in output_var_map_id
for (auto&& item : m_load_result.output_var_map_id) {
item.second = out_var_map[item.second];
}
// replace symvar in output_var_map
for (auto&& item : m_load_result.output_var_map) {
item.second = out_var_map[item.second];
}
m_load_result.update_output_var_list(output_var_array);
} else if (m_user_config->auto_optimize_inference) {
//! set model weight preprocess
m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
......
......@@ -8,6 +8,7 @@
#include "../src/misc.h"
#include "lite/network.h"
#include "lite/tensor.h"
#include "megbrain/comp_node.h"
#include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h"
......@@ -167,4 +168,18 @@ __attribute__((unused)) static std::shared_ptr<Tensor> mgb_lar(
#endif
static inline bool check_gpu_available(size_t num) {
if (mgb::CompNode::get_device_count(mgb::CompNode::DeviceType::CUDA) < num) {
mgb_log_warn("skip test case that requires %zu GPU(s)", num);
return false;
}
return true;
}
#define REQUIRE_CUDA() \
{ \
if (!check_gpu_available(1)) { \
return; \
} \
} \
while (0)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include <gtest/gtest.h>
#include <string.h>
#include <memory>
#include "test_common.h"
#include "test_options.h"
using namespace lar;
DECLARE_bool(lite);
DECLARE_bool(cpu);
DECLARE_bool(optimize_for_inference);
#if LITE_WITH_CUDA
DECLARE_bool(cuda);
#endif
namespace {
BOOL_OPTION_WRAP(optimize_for_inference);
BOOL_OPTION_WRAP(lite);
BOOL_OPTION_WRAP(cpu);
#if LITE_WITH_CUDA
BOOL_OPTION_WRAP(cuda);
#endif
} // anonymous namespace
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE) {
DEFINE_WRAP(cpu);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(optimize_for_inference);
}
#if LITE_WITH_OPENCL
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_OPENCL) {
REQUIRE_OPENCL();
DEFINE_WRAP(opencl);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(optimize_for_inference);
}
#endif
#if LITE_WITH_CUDA
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_CUDA) {
REQUIRE_CUDA();
DEFINE_WRAP(cuda);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(optimize_for_inference);
}
#endif
#include <gtest/gtest.h>
#include <string.h>
#include <memory>
#include "test_common.h"
#include "test_options.h"
using namespace lar;
......
......@@ -109,6 +109,16 @@ struct GraphCommonOptimizeOptions {
///< support on Nvidia GPU
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
void clear() {
f16_io_f32_comp = false;
f16_io_comp = false;
fuse_conv_bias_nonlinearity = false;
fuse_conv_bias_with_z = false;
weight_preprocess = false;
fuse_preprocess = false;
fuse_grain = false;
layout_transform = LayoutTransform::DEFAULT;
}
#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
......
......@@ -312,6 +312,9 @@ public:
};
struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
OptimizeForInferenceOptions() = default;
OptimizeForInferenceOptions(const cg::GraphCommonOptimizeOptions& opt)
: cg::GraphCommonOptimizeOptions(opt){};
uint64_t serialize() {
uint64_t ret = 0;
ret |= (uint64_t)layout_transform << 32;
......
......@@ -17,6 +17,25 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
return ret;
}
void GraphLoader::LoadResult::update_output_var_list(
const SymbolVarArray& output_var_array) {
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
mgb_assert(output_var_array.size() == output_var_list.size());
// replace symvar in output_var_list
for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
out_var_map[output_var_list[idx]] = output_var_array[idx];
output_var_list[idx] = output_var_array[idx];
}
// replace symvar in output_var_map_id
for (auto&& item : output_var_map_id) {
item.second = out_var_map[item.second];
}
// replace symvar in output_var_map
for (auto&& item : output_var_map) {
item.second = out_var_map[item.second].rename(item.first);
}
}
void GraphLoader::LoadResult::graph_compile_ahead() {
//! when force_output_use_user_specified_memory is set, the output var may
//! be changed by gopt, then the var in LoadResult can not exist, so here
......
......@@ -45,6 +45,13 @@ public:
//! GraphDumper::dump
SymbolVarArray output_var_list;
/**
* \brief update output_var_list with output_var_map, output_var_map_id
*
*/
MGE_WIN_DECLSPEC_FUC void update_output_var_list(
const SymbolVarArray& output_var_array);
/*!
* \brief call graph->compile() but also checks for comp seq rec
*
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册