From e65e3f0579b096df518d56c7b417ef322aee14d0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 29 Nov 2021 10:48:26 +0800 Subject: [PATCH] fix(lite/load_and_run): fix some bugs of load and run GitOrigin-RevId: ffd578b97b458454b434c21b78f68bc2cafc63d6 --- lite/load_and_run/src/helpers/outdumper.cpp | 2 +- lite/load_and_run/src/models/model_lite.h | 2 +- lite/load_and_run/src/models/model_mdl.h | 5 +++++ .../load_and_run/src/options/device_options.cpp | 2 +- .../src/options/fastrun_options.cpp | 8 ++++---- lite/load_and_run/src/options/io_options.cpp | 17 ++++++++--------- .../src/options/optimize_options.cpp | 2 +- .../load_and_run/src/options/plugin_options.cpp | 5 ++--- 8 files changed, 23 insertions(+), 20 deletions(-) diff --git a/lite/load_and_run/src/helpers/outdumper.cpp b/lite/load_and_run/src/helpers/outdumper.cpp index 9a5d8315b..7fb90c423 100644 --- a/lite/load_and_run/src/helpers/outdumper.cpp +++ b/lite/load_and_run/src/helpers/outdumper.cpp @@ -39,7 +39,7 @@ void OutputDumper::write_to_file() { info.owner_inputs_info.c_str())); mgb::debug::write_to_file( mgb::ssprintf( - "%s/run%zu-var %zd", dump_file.c_str(), m_run_id, info.id) + "%s/run%zu-var%zd", dump_file.c_str(), m_run_id, info.id) .c_str(), value); } diff --git a/lite/load_and_run/src/models/model_lite.h b/lite/load_and_run/src/models/model_lite.h index 66e7aa983..dc22fc7dd 100644 --- a/lite/load_and_run/src/models/model_lite.h +++ b/lite/load_and_run/src/models/model_lite.h @@ -40,7 +40,7 @@ public: void wait() override; //! get the network of lite model - std::shared_ptr get_lite_network() { return m_network; } + std::shared_ptr& get_lite_network() { return m_network; } //! get the config of lite model lite::Config& get_config() { return config; } diff --git a/lite/load_and_run/src/models/model_mdl.h b/lite/load_and_run/src/models/model_mdl.h index 59d27bd91..07211e46b 100644 --- a/lite/load_and_run/src/models/model_mdl.h +++ b/lite/load_and_run/src/models/model_mdl.h @@ -67,13 +67,16 @@ public: //! get data parser DataParser& get_input_parser() { return parser; } + uint32_t get_testcase_num() { return testcase_num; } + std::vector>& get_test_input() { return test_input_tensors; } //! get output specified configuration mgb::ComputingGraph::OutputSpec& get_output_spec() { return m_output_spec; } + std::unique_ptr& get_async_func() { return m_asyc_exec; } void set_output_callback(std::vector& cb) { @@ -84,6 +87,7 @@ public: m_callbacks[i] = cb[i]; } } + #if MGB_ENABLE_JSON std::unique_ptr& get_profiler() { return m_profiler; } void set_profiler() { @@ -91,6 +95,7 @@ public: std::make_unique(m_load_config.comp_graph.get()); } #endif + void set_num_range_checker(float range) { m_num_range_checker = std::make_unique( m_load_config.comp_graph.get(), range); diff --git a/lite/load_and_run/src/options/device_options.cpp b/lite/load_and_run/src/options/device_options.cpp index 3365d8bc2..bc1825060 100644 --- a/lite/load_and_run/src/options/device_options.cpp +++ b/lite/load_and_run/src/options/device_options.cpp @@ -37,7 +37,7 @@ void XPUDeviceOption::config_model_internel( } #endif } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { - auto network = model->get_lite_network(); + auto&& network = model->get_lite_network(); if (enable_cpu_default) { LITE_WARN("using cpu default device\n"); lite::Runtime::set_cpu_inplace_mode(network); diff --git a/lite/load_and_run/src/options/fastrun_options.cpp b/lite/load_and_run/src/options/fastrun_options.cpp index 764bfeb90..7bcb728f5 100644 --- a/lite/load_and_run/src/options/fastrun_options.cpp +++ b/lite/load_and_run/src/options/fastrun_options.cpp @@ -55,8 +55,8 @@ void FastRunOption::config_model_internel( auto lite_strategy = static_cast(strategy); model->set_lite_strategy(lite_strategy); } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { - auto lite_network = model->get_lite_network(); - auto lite_strategy = model->get_lite_strategy(); + auto&& lite_network = model->get_lite_network(); + auto&& lite_strategy = model->get_lite_strategy(); //! set algo policy for model lite::Runtime::set_network_algo_policy( lite_network, lite_strategy, share_batch_size, batch_binary_equal); @@ -121,8 +121,8 @@ void FastRunOption::config_model_internel( .fast_run_config.shared_batch_size = share_batch_size; } } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { - auto vars = model->get_mdl_load_result().output_var_list; - auto strategy = model->get_mdl_strategy(); + auto& vars = model->get_mdl_load_result().output_var_list; + auto&& strategy = model->get_mdl_strategy(); mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); // set algo cache path if (!m_fast_run_cache.empty()) { diff --git a/lite/load_and_run/src/options/io_options.cpp b/lite/load_and_run/src/options/io_options.cpp index 961ca99cc..612aa073f 100644 --- a/lite/load_and_run/src/options/io_options.cpp +++ b/lite/load_and_run/src/options/io_options.cpp @@ -20,8 +20,8 @@ template <> void InputOption::config_model_internel( RuntimeParam& runtime_param, std::shared_ptr model) { if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { - auto parser = model->get_input_parser(); - auto io = model->get_networkIO(); + auto&& parser = model->get_input_parser(); + auto&& io = model->get_networkIO(); for (size_t idx = 0; idx < data_path.size(); ++idx) { parser.feed(data_path[idx].c_str()); } @@ -32,9 +32,8 @@ void InputOption::config_model_internel( io.inputs.push_back({i.first, is_host}); } } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { - auto config = model->get_config(); - auto parser = model->get_input_parser(); - auto network = model->get_lite_network(); + auto&& parser = model->get_input_parser(); + auto&& network = model->get_lite_network(); //! datd type map from mgb data type to lite data type std::map type_map = { @@ -75,8 +74,8 @@ void InputOption::config_model_internel( parser.feed(data_path[idx].c_str()); } } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { - auto parser = model->get_input_parser(); - auto network = model->get_mdl_load_result(); + auto&& parser = model->get_input_parser(); + auto&& network = model->get_mdl_load_result(); auto tensormap = network.tensor_map; for (auto& i : parser.inputs) { mgb_assert( @@ -156,7 +155,7 @@ void IOdumpOption::config_model_internel( } } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { if (enable_bin_out_dump) { - auto load_result = model->get_mdl_load_result(); + auto&& load_result = model->get_mdl_load_result(); out_dumper->set(load_result.output_var_list); std::vector cb; @@ -166,7 +165,7 @@ void IOdumpOption::config_model_internel( model->set_output_callback(cb); } if (enable_copy_to_host) { - auto load_result = model->get_mdl_load_result(); + auto&& load_result = model->get_mdl_load_result(); std::vector cb; for (size_t i = 0; i < load_result.output_var_list.size(); i++) { diff --git a/lite/load_and_run/src/options/optimize_options.cpp b/lite/load_and_run/src/options/optimize_options.cpp index c684a3afa..f600aa131 100644 --- a/lite/load_and_run/src/options/optimize_options.cpp +++ b/lite/load_and_run/src/options/optimize_options.cpp @@ -365,7 +365,7 @@ void MemoryOptimizeOption::config_model_internel( } if (workspace_limit < SIZE_MAX) { mgb_log_warn("set workspace limit to %ld", workspace_limit); - auto output_spec = model->get_output_spec(); + auto&& output_spec = model->get_output_spec(); mgb::SymbolVarArray vars; for (auto i : output_spec) { vars.push_back(i.first); diff --git a/lite/load_and_run/src/options/plugin_options.cpp b/lite/load_and_run/src/options/plugin_options.cpp index 8b9668bdd..a3d622a6b 100644 --- a/lite/load_and_run/src/options/plugin_options.cpp +++ b/lite/load_and_run/src/options/plugin_options.cpp @@ -46,7 +46,7 @@ template <> void PluginOption::config_model_internel( RuntimeParam& runtime_param, std::shared_ptr model) { if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { - auto config = model->get_mdl_config(); + auto&& config = model->get_mdl_config(); if (range > 0) { mgb_log_warn("enable number range check"); model->set_num_range_checker(float(range)); @@ -151,7 +151,7 @@ template <> void DebugOption::format_and_print( const std::string& tablename, std::shared_ptr model) { auto table = mgb::TextTable(tablename); - auto network = model->get_lite_network(); + auto&& network = model->get_lite_network(); table.padding(1); table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor(); @@ -259,7 +259,6 @@ template <> void DebugOption::config_model_internel( RuntimeParam& runtime_param, std::shared_ptr model) { if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { - auto config = model->get_mdl_config(); if (enable_verbose) { mgb_log_warn("enable verbose"); mgb::set_log_level(mgb::LogLevel::DEBUG); -- GitLab