Commit cc218550 authored by Megvii Engine Team

feat(lite): load_and_run support optimize for inference

GitOrigin-RevId: d9abb8de9eb81c357c3f5110b2467582d80b6c2c
Parent 9bbe5500
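The commit wires three gflags switches into load_and_run: --optimize_for_inference, plus --lite / --mdl to force which interface loads the model. A minimal usage sketch (the binary location and model file below are assumptions, not part of this diff):

    load_and_run ./model.mge --optimize_for_inference   # apply the extra inference passes after load
    load_and_run ./model.mge --lite                     # force the megengine lite interface
    load_and_run ./model.mge --mdl                      # force the MegDL (mdl) interface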
@@ -36,16 +36,19 @@ std::shared_ptr<ModelBase> ModelBase::create_model(std::string model_path) {
     auto model_type = get_model_type(model_path);
-    if (ModelType::LITE_MODEL == model_type) {
+    if (FLAGS_lite) {
+        mgb_log("run model force lite mode\n");
+        return std::make_shared<ModelLite>(model_path);
+    } else if (FLAGS_mdl) {
+        mgb_log("run model force mdl mode\n");
+        return std::make_shared<ModelMdl>(model_path);
+    } else if (ModelType::LITE_MODEL == model_type) {
         return std::make_shared<ModelLite>(model_path);
-    } else if (ModelType::MEGDL_MODEL == model_type) {
-        if (FLAGS_lite)
-            return std::make_shared<ModelLite>(model_path);
-        else
-            return std::make_shared<ModelMdl>(model_path);
     } else {
-        return nullptr;
+        mgb_assert(ModelType::MEGDL_MODEL == model_type);
+        return std::make_shared<ModelMdl>(model_path);
     }
 }
 
 DEFINE_bool(lite, false, "use megengine lite interface to run model");
+DEFINE_bool(mdl, false, "use megengine mdl interface to run model");
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -4,6 +4,7 @@
 #include "helpers/common.h"
 #include "megbrain/utils/json.h"
 
 DECLARE_bool(lite);
+DECLARE_bool(mdl);
 namespace lar {
 /*!
......
@@ -42,6 +42,8 @@ public:
         return m_load_result;
     }
 
+    void update_mdl_load_result(const mgb::SymbolVarArray& output_var_array);
+
     //! get load config for megDL model
     mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; }
......
@@ -31,6 +31,13 @@ void FastRunOption::config_model_internel<ModelLite>(
             LITE_LOG("enable fast-run strategy for algo profile");
             strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
                        static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
+        } else if ((!m_fast_run_cache.empty() &&
+                    !access(m_fast_run_cache.c_str(), F_OK))) {
+            LITE_LOG(
+                    "detect fast-run cache usable set LITE_ALGO_PROFILE for algo "
+                    "profile");
+            strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
+                       static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
         } else {
             strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
         }
......
@@ -299,6 +299,75 @@ void FuseConvBiasElemwiseAddOption::config_model(
     CONFIG_MODEL_FUN;
 }
 
+///////////////////////// optimize for inference options ///////////////
+bool OptimizeForInferenceOption::m_valid;
+namespace lar {
+template <>
+void OptimizeForInferenceOption::config_model_internel<ModelLite>(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
+    LITE_MARK_USED_VAR(model);
+    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+        auto optimize_for_infer =
+                std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
+                        ->get_value();
+        if (optimize_for_infer) {
+            LITE_THROW(
+                    "optimize for inference not supported in lite "
+                    "model");
+        }
+    }
+}
+
+template <>
+void OptimizeForInferenceOption::config_model_internel<ModelMdl>(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
+    if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
+        auto optimize_for_infer =
+                std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
+                        ->get_value();
+        if (optimize_for_infer) {
+            mgb_log("enable optimize for inference optimization");
+            auto&& load_result = model->get_mdl_load_result();
+            mgb::cg::GraphCommonOptimizeOptions opt =
+                    model->get_mdl_load_result().graph->options().graph_opt;
+            auto inference_opt2 = mgb::gopt::OptimizeForInferenceOptions(opt);
+            auto output_var_list = mgb::gopt::optimize_for_inference(
+                    load_result.output_var_list, inference_opt2);
+            model->get_mdl_load_result().update_output_var_list(output_var_list);
+            model->get_mdl_load_result().graph->options().graph_opt.clear();
+        }
+    }
+}
+} // namespace lar
+
+void OptimizeForInferenceOption::update() {
+    m_option_name = "optimize_for_inference";
+    m_option = {{"optimize_for_inference", lar::Bool::make(false)}};
+    std::static_pointer_cast<lar::Bool>(m_option["optimize_for_inference"])
+            ->set_value(FLAGS_optimize_for_inference);
+}
+
+bool OptimizeForInferenceOption::is_valid() {
+    bool ret = FLAGS_optimize_for_inference;
+    return ret || m_valid;
+}
+
+std::shared_ptr<OptionBase> OptimizeForInferenceOption::create_option() {
+    static std::shared_ptr<OptimizeForInferenceOption> option(
+            new OptimizeForInferenceOption);
+    if (OptimizeForInferenceOption::is_valid()) {
+        option->update();
+        return option;
+    } else {
+        return nullptr;
+    }
+}
+
+void OptimizeForInferenceOption::config_model(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
+    CONFIG_MODEL_FUN;
+}
+
 ///////////////////////// graph retrict options /////////////////////////
 bool GraphRecordOption::m_valid;
 namespace lar {
@@ -646,6 +715,9 @@ DEFINE_bool(
         enable_fuse_conv_bias_with_z, false,
         "fuse conv, bias (elemwise add), z(elemwise add) into one opr "
         "(only support on GPU)");
+DEFINE_bool(
+        optimize_for_inference, false,
+        "whether to optimize_for_inference, fuse bn and many base optimize");
 
 ///////////////////////// graph retrict options /////////////////////////
 DEFINE_bool(
@@ -699,6 +771,11 @@ REGIST_OPTION_CREATOR(
 REGIST_OPTION_VALIDATER(
         fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::set_valid);
 
+REGIST_OPTION_CREATOR(
+        optimize_for_inference, lar::OptimizeForInferenceOption::create_option);
+
+REGIST_OPTION_VALIDATER(
+        optimize_for_inference, lar::OptimizeForInferenceOption::set_valid);
 REGIST_OPTION_CREATOR(
         fuse_conv_bias_with_z, lar::FuseConvBiasElemwiseAddOption::create_option);
 REGIST_OPTION_VALIDATER(
......
@@ -5,6 +5,7 @@
 #include "option_base.h"
 
 DECLARE_bool(enable_fuse_preprocess);
+DECLARE_bool(optimize_for_inference);
 DECLARE_bool(fuse_grain);
 DECLARE_bool(weight_preprocess);
 DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
@@ -216,6 +217,34 @@ private:
     uint64_t workspace_limit;
 };
 
+///////////////////////// optimize for inference options /////////////////////////
+class OptimizeForInferenceOption final : public OptionBase {
+public:
+    static bool is_valid();
+
+    static std::shared_ptr<OptionBase> create_option();
+
+    void config_model(
+            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
+
+    static void set_valid(bool val) { m_valid = val; }
+
+    std::string option_name() const override { return m_option_name; };
+
+    OptionValMap* get_option() override { return &m_option; }
+
+    void update() override;
+
+private:
+    OptimizeForInferenceOption() = default;
+    template <typename ModelImpl>
+    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
+
+    std::string m_option_name;
+    static bool m_valid;
+    OptionValMap m_option;
+};
+
 ///////////////////////// other options for optimization /////////////////
 class JITOption final : public OptionBase {
 public:
......
@@ -366,19 +366,7 @@ void NetworkImplDft::layout_transform_optimization() {
         mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
         auto output_var_array = mgb::gopt::layout_transform(
                 m_load_result.output_var_list, m_layout_transform_target);
-        // replace symvar in output_var_list
-        for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
-            out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx];
-            m_load_result.output_var_list[idx] = output_var_array[idx];
-        }
-        // replace symvar in output_var_map_id
-        for (auto&& item : m_load_result.output_var_map_id) {
-            item.second = out_var_map[item.second];
-        }
-        // replace symvar in output_var_map
-        for (auto&& item : m_load_result.output_var_map) {
-            item.second = out_var_map[item.second];
-        }
+        m_load_result.update_output_var_list(output_var_array);
     } else if (m_user_config->auto_optimize_inference) {
         //! set model weight preprocess
         m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
......
@@ -8,6 +8,7 @@
 #include "../src/misc.h"
 #include "lite/network.h"
 #include "lite/tensor.h"
+#include "megbrain/comp_node.h"
 #include "megbrain/graph/bases.h"
 #include "megbrain/plugin/opr_io_dump.h"
 #include "megbrain/plugin/profiler.h"
@@ -167,4 +168,18 @@ __attribute__((unused)) static std::shared_ptr<Tensor> mgb_lar(
 #endif
 
+static inline bool check_gpu_available(size_t num) {
+    if (mgb::CompNode::get_device_count(mgb::CompNode::DeviceType::CUDA) < num) {
+        mgb_log_warn("skip test case that requires %zu GPU(s)", num);
+        return false;
+    }
+    return true;
+}
+#define REQUIRE_CUDA()                     \
+    {                                      \
+        if (!check_gpu_available(1)) {     \
+            return;                        \
+        }                                  \
+    }                                      \
+    while (0)
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include <gtest/gtest.h>
#include <string.h>
#include <memory>
#include "test_common.h"
#include "test_options.h"

using namespace lar;

DECLARE_bool(lite);
DECLARE_bool(cpu);
DECLARE_bool(optimize_for_inference);
#if LITE_WITH_CUDA
DECLARE_bool(cuda);
#endif

namespace {
BOOL_OPTION_WRAP(optimize_for_inference);
BOOL_OPTION_WRAP(lite);
BOOL_OPTION_WRAP(cpu);
#if LITE_WITH_CUDA
BOOL_OPTION_WRAP(cuda);
#endif
} // anonymous namespace

TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE) {
    DEFINE_WRAP(cpu);
    std::string model_path = "./shufflenet.mge";
    TEST_BOOL_OPTION(optimize_for_inference);
}

#if LITE_WITH_OPENCL
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_OPENCL) {
    REQUIRE_OPENCL();
    DEFINE_WRAP(opencl);
    std::string model_path = "./shufflenet.mge";
    TEST_BOOL_OPTION(optimize_for_inference);
}
#endif

#if LITE_WITH_CUDA
TEST(TestLarOption, OPTIMIZE_FOR_INFERENCE_CUDA) {
    REQUIRE_CUDA();
    DEFINE_WRAP(cuda);
    std::string model_path = "./shufflenet.mge";
    TEST_BOOL_OPTION(optimize_for_inference);
}
#endif
 #include <gtest/gtest.h>
 #include <string.h>
 #include <memory>
+#include "test_common.h"
 #include "test_options.h"
 using namespace lar;
......
@@ -109,6 +109,16 @@ struct GraphCommonOptimizeOptions {
                                  ///< support on Nvidia GPU
     };
     LayoutTransform layout_transform = LayoutTransform::DEFAULT;
+    void clear() {
+        f16_io_f32_comp = false;
+        f16_io_comp = false;
+        fuse_conv_bias_nonlinearity = false;
+        fuse_conv_bias_with_z = false;
+        weight_preprocess = false;
+        fuse_preprocess = false;
+        fuse_grain = false;
+        layout_transform = LayoutTransform::DEFAULT;
+    }
 
 #define SET(n) \
     GraphCommonOptimizeOptions& enable_##n() { \
......
@@ -312,6 +312,9 @@
 };
 
 struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
+    OptimizeForInferenceOptions() = default;
+    OptimizeForInferenceOptions(const cg::GraphCommonOptimizeOptions& opt)
+            : cg::GraphCommonOptimizeOptions(opt){};
     uint64_t serialize() {
         uint64_t ret = 0;
         ret |= (uint64_t)layout_transform << 32;
......
@@ -17,6 +17,25 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
     return ret;
 }
 
+void GraphLoader::LoadResult::update_output_var_list(
+        const SymbolVarArray& output_var_array) {
+    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
+    mgb_assert(output_var_array.size() == output_var_list.size());
+    // replace symvar in output_var_list
+    for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
+        out_var_map[output_var_list[idx]] = output_var_array[idx];
+        output_var_list[idx] = output_var_array[idx];
+    }
+    // replace symvar in output_var_map_id
+    for (auto&& item : output_var_map_id) {
+        item.second = out_var_map[item.second];
+    }
+    // replace symvar in output_var_map
+    for (auto&& item : output_var_map) {
+        item.second = out_var_map[item.second].rename(item.first);
+    }
+}
+
 void GraphLoader::LoadResult::graph_compile_ahead() {
     //! when force_output_use_user_specified_memory is set, the output var may
     //! be changed by gopt, then the var in LoadResult can not exist, so here
......
@@ -45,6 +45,13 @@ public:
     //! GraphDumper::dump
     SymbolVarArray output_var_list;
 
+    /**
+     * \brief update output_var_list with output_var_map, output_var_map_id
+     *
+     */
+    MGE_WIN_DECLSPEC_FUC void update_output_var_list(
+            const SymbolVarArray& output_var_array);
+
     /*!
      * \brief call graph->compile() but also checks for comp seq rec
      *
......
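Taken together, the pieces above form one flow for MDL models: read the requested graph_opt flags, run optimize_for_inference, remap the load result, and clear the flags. The sketch below restates that flow as a standalone helper; it mirrors the option code in this commit, but the header paths and the helper name are assumptions, not part of the diff.

    #include "megbrain/gopt/inference.h"
    #include "megbrain/serialization/serializer.h"

    // Hypothetical helper mirroring what OptimizeForInferenceOption does for MDL
    // models; `load_result` is assumed to come from GraphLoader::load().
    static void apply_optimize_for_inference(
            mgb::serialization::GraphLoader::LoadResult& load_result) {
        // reuse whatever graph_opt flags were already requested on the graph
        mgb::gopt::OptimizeForInferenceOptions opts(
                load_result.graph->options().graph_opt);  // new converting ctor
        auto new_outputs = mgb::gopt::optimize_for_inference(
                load_result.output_var_list, opts);
        // remap output_var_list / output_var_map / output_var_map_id in one call
        load_result.update_output_var_list(new_outputs);
        // clear the flags so the same passes are not applied again at compile time
        load_result.graph->options().graph_opt.clear();
    }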