提交 8461c8d8 编写于 作者: M Megvii Engine Team

fix(lite): fix ldr use lite interface error when open both fast-run and nchw44

GitOrigin-RevId: 27b29d60af17c61bc52f06770f76bf9227647605
上级 43bd949a
......@@ -11,27 +11,29 @@ enum class RunStage {
BEFORE_MODEL_LOAD = 0,
AFTER_MODEL_LOAD = 1,
AFTER_NETWORK_CREATED = 1,
BEFORE_OUTSPEC_SET = 2,
AFTER_MODEL_LOAD = 2,
BEFORE_OUTSPEC_SET = 3,
//! using for dump static memory information svg file
AFTER_OUTSPEC_SET = 3,
AFTER_OUTSPEC_SET = 4,
//! using for external c opr library
MODEL_RUNNING = 4,
MODEL_RUNNING = 5,
//! using for output dumper
AFTER_RUNNING_WAIT = 5,
AFTER_RUNNING_WAIT = 6,
//! using for external c opr library
AFTER_RUNNING_ITER = 6,
AFTER_RUNNING_ITER = 7,
AFTER_MODEL_RUNNING = 7,
AFTER_MODEL_RUNNING = 8,
GLOBAL_OPTIMIZATION = 8,
GLOBAL_OPTIMIZATION = 9,
UPDATE_IO = 9,
UPDATE_IO = 10,
};
/*!
* \brief: type of different model
......
......@@ -24,6 +24,8 @@ public:
virtual void set_shared_mem(bool state) = 0;
virtual void create_network(){};
//! load model interface for load and run strategy
virtual void load_model() = 0;
......
......@@ -10,12 +10,12 @@ using namespace lar;
ModelLite::ModelLite(const std::string& path) : model_path(path) {
LITE_LOG("creat lite model use CPU as default comp node");
};
void ModelLite::load_model() {
void ModelLite::create_network() {
m_network = std::make_shared<lite::Network>(config, IO);
if (enable_layout_transform) {
LITE_LOG("enable layout transform while load model for lite");
lite::Runtime::enable_global_layout_transform(m_network);
}
}
void ModelLite::load_model() {
if (share_model_mem) {
//! WARNNING:maybe not right to share param memmory for this
LITE_LOG("enable share model memory");
......
......@@ -21,6 +21,9 @@ public:
//! set to load from shared memory
void set_shared_mem(bool state) override { share_model_mem = state; }
//! load model from dump file
void create_network() override;
//! load model from dump file
void load_model() override;
......@@ -34,9 +37,6 @@ public:
std::shared_ptr<mgb::json::Object> get_io_info() override;
#endif
//! enable global layout transform
void set_layout_transform(bool state) { enable_layout_transform = state; }
//! get the network of lite model
std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }
......@@ -61,7 +61,6 @@ public:
private:
bool share_model_mem = false;
bool enable_layout_transform = false;
std::string model_path;
DataParser parser;
......
......@@ -19,7 +19,7 @@ namespace lar {
template <>
void FastRunOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
//! set the algo policy before model load
using Strategy = ModelLite::Strategy;
uint32_t strategy = 0;
......@@ -44,23 +44,17 @@ void FastRunOption::config_model_internel<ModelLite>(
strategy;
}
auto lite_strategy = static_cast<Strategy>(strategy);
model->set_lite_strategy(lite_strategy);
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
auto&& lite_network = model->get_lite_network();
auto&& lite_strategy = model->get_lite_strategy();
//! set algo policy for model
auto&& lite_network = model->get_lite_network();
lite::Runtime::set_network_algo_policy(
lite_network, lite_strategy, share_batch_size, batch_binary_equal);
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (!m_fast_run_cache.empty()) {
if (!access(m_fast_run_cache.c_str(), F_OK)) {
lite::set_persistent_cache(m_fast_run_cache);
} else {
lite::set_persistent_cache(m_fast_run_cache, true);
}
//! TODO:this is from mdl model settings but not matched settings in
//! lite model
// if (!enable_full_run && !enable_fast_run)
// mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
#if MGB_ENABLE_FASTRUN
......
......@@ -9,7 +9,7 @@ namespace lar {
template <>
void GoptLayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
if (m_layout_transform) {
LITE_LOG("using global layout transform optimization\n");
if (m_layout_transform_target ==
......@@ -23,7 +23,9 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
model->get_config().device_type = LiteDeviceType::LITE_CUDA;
}
#endif
model->set_layout_transform(true);
LITE_LOG("enable layout transform while load model for lite");
auto&& lite_network = model->get_lite_network();
lite::Runtime::enable_global_layout_transform(lite_network);
}
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
......
......@@ -197,6 +197,10 @@ void OptionsTimeProfiler::profile_with_given_options(
runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
stage_config_model();
runtime_param.stage = RunStage::AFTER_NETWORK_CREATED;
model->create_network();
stage_config_model();
model->load_model();
//! after load configure
auto config_model_before_runing = [&]() {
......
......@@ -42,6 +42,10 @@ void NormalStrategy::run_subline() {
m_runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
stage_config_model();
m_runtime_param.stage = RunStage::AFTER_NETWORK_CREATED;
model->create_network();
stage_config_model();
mgb::RealTimer timer;
model->load_model();
mgb_log("load model: %.3fms\n", timer.get_msecs_reset());
......
......@@ -18,6 +18,7 @@ DECLARE_bool(enable_nchw32);
DECLARE_bool(enable_nchw64);
DECLARE_bool(enable_nhwcd4);
DECLARE_bool(enable_nchw44_dot);
DECLARE_bool(fast_run);
namespace {
BOOL_OPTION_WRAP(enable_nchw4);
BOOL_OPTION_WRAP(enable_chwn4);
......@@ -27,6 +28,7 @@ BOOL_OPTION_WRAP(enable_nchw32);
BOOL_OPTION_WRAP(enable_nchw64);
BOOL_OPTION_WRAP(enable_nhwcd4);
BOOL_OPTION_WRAP(enable_nchw44_dot);
BOOL_OPTION_WRAP(fast_run);
BOOL_OPTION_WRAP(lite);
BOOL_OPTION_WRAP(cpu);
......@@ -60,6 +62,17 @@ TEST(TestLarLayout, X86_CPU_LITE) {
TEST_BOOL_OPTION(enable_nchw32);
TEST_BOOL_OPTION(enable_nchw88);
}
TEST(TestLarLayoutFastRun, CPU_LITE) {
DEFINE_WRAP(cpu);
DEFINE_WRAP(lite);
std::string model_path = "./shufflenet.mge";
{
DEFINE_WRAP(enable_nchw44);
DEFINE_WRAP(fast_run);
run_NormalStrategy(model_path);
}
}
#if LITE_WITH_CUDA
TEST(TestLarLayout, CUDA) {
DEFINE_WRAP(cuda);
......
......@@ -27,7 +27,7 @@ void run_NormalStrategy(std::string model_path);
#define TEST_BOOL_OPTION(option) \
{ \
BoolOptionWrap_##option flags_##option; \
DEFINE_WRAP(option); \
run_NormalStrategy(model_path); \
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册