Commit 8461c8d8, authored by Megvii Engine Team

fix(lite): fix load_and_run (ldr) error when using the lite interface with both fast-run and nchw44 enabled

GitOrigin-RevId: 27b29d60af17c61bc52f06770f76bf9227647605
Parent: 43bd949a
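In brief: before this change, ModelLite::load_model() both created the lite::Network and, when the flag was set, applied the global layout transform, while the fast-run option stored its strategy at BEFORE_MODEL_LOAD and only applied the algo policy at AFTER_MODEL_LOAD. With fast-run and nchw44 both enabled, the layout transform therefore ran during load, before the fast-run policy was in effect. The fix splits network creation into a new ModelLite::create_network() step, adds an AFTER_NETWORK_CREATED stage between BEFORE_MODEL_LOAD and AFTER_MODEL_LOAD, and applies both the algo policy and the layout transform at that stage, on the freshly created network.

For orientation, a minimal sketch of the ordering this enforces, written directly against the MegEngine Lite C++ API (assumptions: the lite/network.h include path, the LITE_ALGO_PROFILE strategy value, and the placeholder model path; this is not code from the commit):

    #include <memory>

    #include "lite/network.h"

    int main() {
        lite::Config config;
        auto network = std::make_shared<lite::Network>(config);

        // Both calls configure the Network object, so they must run after the
        // network exists but before the model file is loaded into it:
        lite::Runtime::set_network_algo_policy(
                network, LiteAlgoSelectStrategy::LITE_ALGO_PROFILE);
        lite::Runtime::enable_global_layout_transform(network);

        network->load_model("./shufflenet.mge");  // placeholder path
        return 0;
    }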
@@ -11,27 +11,29 @@ enum class RunStage {
     BEFORE_MODEL_LOAD = 0,
-    AFTER_MODEL_LOAD = 1,
-    BEFORE_OUTSPEC_SET = 2,
+    AFTER_NETWORK_CREATED = 1,
+    AFTER_MODEL_LOAD = 2,
+    BEFORE_OUTSPEC_SET = 3,
     //! using for dump static memory information svg file
-    AFTER_OUTSPEC_SET = 3,
+    AFTER_OUTSPEC_SET = 4,
     //! using for external c opr library
-    MODEL_RUNNING = 4,
+    MODEL_RUNNING = 5,
     //! using for output dumper
-    AFTER_RUNNING_WAIT = 5,
+    AFTER_RUNNING_WAIT = 6,
     //! using for external c opr library
-    AFTER_RUNNING_ITER = 6,
-    AFTER_MODEL_RUNNING = 7,
-    GLOBAL_OPTIMIZATION = 8,
-    UPDATE_IO = 9,
+    AFTER_RUNNING_ITER = 7,
+    AFTER_MODEL_RUNNING = 8,
+    GLOBAL_OPTIMIZATION = 9,
+    UPDATE_IO = 10,
 };
 /*!
  * \brief: type of different model
@@ -24,6 +24,8 @@ public:
     virtual void set_shared_mem(bool state) = 0;
 
+    virtual void create_network(){};
+
     //! load model interface for load and run strategy
     virtual void load_model() = 0;
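The new hook defaults to a no-op, so model back-ends whose graph is built inside load_model() compile unchanged; only ModelLite overrides it. A condensed view of the interface after this change (other members are elided, and "ModelBase" is the assumed class name, not confirmed by this diff):

    class ModelBase {
    public:
        virtual void set_shared_mem(bool state) = 0;

        //! optional early step: construct the network before any model data
        //! is loaded, so options can configure it at AFTER_NETWORK_CREATED
        virtual void create_network(){};

        //! load model interface for load and run strategy
        virtual void load_model() = 0;

        virtual ~ModelBase() = default;
    };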
@@ -10,12 +10,12 @@ using namespace lar;
 ModelLite::ModelLite(const std::string& path) : model_path(path) {
     LITE_LOG("creat lite model use CPU as default comp node");
 };
-void ModelLite::load_model() {
+
+void ModelLite::create_network() {
     m_network = std::make_shared<lite::Network>(config, IO);
-    if (enable_layout_transform) {
-        LITE_LOG("enable layout transform while load model for lite");
-        lite::Runtime::enable_global_layout_transform(m_network);
-    }
+}
+
+void ModelLite::load_model() {
     if (share_model_mem) {
         //! WARNNING:maybe not right to share param memmory for this
         LITE_LOG("enable share model memory");

@@ -116,4 +116,4 @@ std::vector<uint8_t> ModelLite::get_model_data() {
     LITE_THROW("unsupported interface: ModelLite::get_model_data() \n");
     return out_data;
-}
\ No newline at end of file
+}
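The split means a caller can now get at the network between creation and load. A rough sketch of the sequence this enables (identifiers taken from this diff; driver glue, includes, and error handling are omitted, so this is illustrative rather than compilable on its own):

    auto model = std::make_shared<lar::ModelLite>("./shufflenet.mge");
    model->create_network();                     // m_network exists; nothing loaded yet
    auto&& network = model->get_lite_network();  // safe: no longer a null pointer
    lite::Runtime::enable_global_layout_transform(network);
    model->load_model();                         // load with options already applied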
@@ -21,6 +21,9 @@ public:
     //! set to load from shared memory
     void set_shared_mem(bool state) override { share_model_mem = state; }
 
+    //! create the lite network
+    void create_network() override;
+
     //! load model from dump file
     void load_model() override;

@@ -34,9 +37,6 @@ public:
     std::shared_ptr<mgb::json::Object> get_io_info() override;
 #endif
 
-    //! enable global layout transform
-    void set_layout_transform(bool state) { enable_layout_transform = state; }
-
     //! get the network of lite model
     std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }

@@ -61,7 +61,6 @@ public:
 private:
     bool share_model_mem = false;
-    bool enable_layout_transform = false;
     std::string model_path;
     DataParser parser;
@@ -19,7 +19,7 @@ namespace lar {
 template <>
 void FastRunOption::config_model_internel<ModelLite>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
-    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+    if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
         //! set the algo policy before model load
         using Strategy = ModelLite::Strategy;
         uint32_t strategy = 0;

@@ -44,23 +44,17 @@ void FastRunOption::config_model_internel<ModelLite>(
                 strategy;
         }
         auto lite_strategy = static_cast<Strategy>(strategy);
-        model->set_lite_strategy(lite_strategy);
-    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-        auto&& lite_network = model->get_lite_network();
-        auto&& lite_strategy = model->get_lite_strategy();
         //! set algo policy for model
+        auto&& lite_network = model->get_lite_network();
         lite::Runtime::set_network_algo_policy(
                 lite_network, lite_strategy, share_batch_size, batch_binary_equal);
+    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
         if (!m_fast_run_cache.empty()) {
             if (!access(m_fast_run_cache.c_str(), F_OK)) {
                 lite::set_persistent_cache(m_fast_run_cache);
             } else {
                 lite::set_persistent_cache(m_fast_run_cache, true);
             }
-            //! TODO:this is from mdl model settings but not matched settings in
-            //! lite model
-            // if (!enable_full_run && !enable_fast_run)
-            //     mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
         }
     } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
 #if MGB_ENABLE_FASTRUN

@@ -255,4 +249,4 @@ DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fast
 DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");
 REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
-REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid);
\ No newline at end of file
+REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid);
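One subtlety in the cache branch kept above: access(path, F_OK) returns 0 when the file exists, so the negated call selects the "reuse" path. A standalone restatement (assuming lite/global.h declares set_persistent_cache(path, always_sync = false), matching the calls in this diff):

    #include <unistd.h>

    #include <string>

    #include "lite/global.h"

    // Reuse an existing fast-run cache file, or create one that is synced to
    // disk as new profiling results are recorded.
    void setup_fast_run_cache(const std::string& path) {
        if (!access(path.c_str(), F_OK)) {
            lite::set_persistent_cache(path);        // file exists: load it
        } else {
            lite::set_persistent_cache(path, true);  // create; always_sync = true
        }
    }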
@@ -9,7 +9,7 @@ namespace lar {
 template <>
 void GoptLayoutOption::config_model_internel<ModelLite>(
         RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
-    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+    if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
         if (m_layout_transform) {
             LITE_LOG("using global layout transform optimization\n");
             if (m_layout_transform_target ==

@@ -23,7 +23,9 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
                 model->get_config().device_type = LiteDeviceType::LITE_CUDA;
             }
 #endif
-            model->set_layout_transform(true);
+            LITE_LOG("enable layout transform while load model for lite");
+            auto&& lite_network = model->get_lite_network();
+            lite::Runtime::enable_global_layout_transform(lite_network);
         }
     } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
         if (m_layout_transform) {

@@ -266,4 +268,4 @@ DEFINE_int32(
         layout_transform_batch_size, -1,
         "the batch size of input for global layout transform optimization working on");
 REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
-REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);
\ No newline at end of file
+REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);
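Note the shape of this change: rather than setting a flag on ModelLite (the enable_layout_transform member deleted above) for load_model() to consume later, the option now acts on the network directly at AFTER_NETWORK_CREATED. That removes the shared state and pins the ordering relative to the fast-run option, which configures its algo policy at the same stage.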
@@ -197,6 +197,10 @@ void OptionsTimeProfiler::profile_with_given_options(
     runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
     stage_config_model();
 
+    runtime_param.stage = RunStage::AFTER_NETWORK_CREATED;
+    model->create_network();
+    stage_config_model();
+
     model->load_model();
     //! after load configure
     auto config_model_before_runing = [&]() {
@@ -42,6 +42,10 @@ void NormalStrategy::run_subline() {
     m_runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
     stage_config_model();
 
+    m_runtime_param.stage = RunStage::AFTER_NETWORK_CREATED;
+    model->create_network();
+    stage_config_model();
+
     mgb::RealTimer timer;
     model->load_model();
     mgb_log("load model: %.3fms\n", timer.get_msecs_reset());
@@ -18,6 +18,7 @@ DECLARE_bool(enable_nchw32);
 DECLARE_bool(enable_nchw64);
 DECLARE_bool(enable_nhwcd4);
 DECLARE_bool(enable_nchw44_dot);
+DECLARE_bool(fast_run);
 
 namespace {
 BOOL_OPTION_WRAP(enable_nchw4);
 BOOL_OPTION_WRAP(enable_chwn4);

@@ -27,6 +28,7 @@ BOOL_OPTION_WRAP(enable_nchw32);
 BOOL_OPTION_WRAP(enable_nchw64);
 BOOL_OPTION_WRAP(enable_nhwcd4);
 BOOL_OPTION_WRAP(enable_nchw44_dot);
+BOOL_OPTION_WRAP(fast_run);
 
 BOOL_OPTION_WRAP(lite);
 BOOL_OPTION_WRAP(cpu);

@@ -60,6 +62,17 @@ TEST(TestLarLayout, X86_CPU_LITE) {
     TEST_BOOL_OPTION(enable_nchw32);
     TEST_BOOL_OPTION(enable_nchw88);
 }
+
+TEST(TestLarLayoutFastRun, CPU_LITE) {
+    DEFINE_WRAP(cpu);
+    DEFINE_WRAP(lite);
+    std::string model_path = "./shufflenet.mge";
+    {
+        DEFINE_WRAP(enable_nchw44);
+        DEFINE_WRAP(fast_run);
+        run_NormalStrategy(model_path);
+    }
+}
 
 #if LITE_WITH_CUDA
 TEST(TestLarLayout, CUDA) {
     DEFINE_WRAP(cuda);
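The new test toggles nchw44 and fast-run together, which is exactly the combination the commit title describes. The BoolOptionWrap_* helpers are not part of this diff; a plausible reconstruction of BOOL_OPTION_WRAP (an assumption inferred from the DEFINE_WRAP usage, not the actual definition) is an RAII wrapper that flips the gflag for the lifetime of a scope:

    // Hypothetical: enables FLAGS_<option> on construction and restores it on
    // destruction, so the inner braces above scope enable_nchw44 + fast_run to
    // a single run_NormalStrategy call.
    #define BOOL_OPTION_WRAP(option)                               \
        struct BoolOptionWrap_##option {                           \
            BoolOptionWrap_##option() { FLAGS_##option = true; }   \
            ~BoolOptionWrap_##option() { FLAGS_##option = false; } \
        };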
@@ -25,9 +25,9 @@ void run_NormalStrategy(std::string model_path);
 #define DEFINE_WRAP(option) BoolOptionWrap_##option flags_##option;
 
 #define TEST_BOOL_OPTION(option)                \
     {                                           \
-        BoolOptionWrap_##option flags_##option; \
+        DEFINE_WRAP(option);                    \
         run_NormalStrategy(model_path);         \
     }
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
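The macro rewrite is behavior-preserving: with DEFINE_WRAP as defined two lines above it, TEST_BOOL_OPTION(enable_nchw44) still expands to

    {
        BoolOptionWrap_enable_nchw44 flags_enable_nchw44;
        run_NormalStrategy(model_path);
    }

so each TEST_BOOL_OPTION call runs the strategy once with the flag wrapper scoped to the block (per the RAII behavior assumed above); the change only removes the duplicated expansion text.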