Commit e65e3f05 authored by Megvii Engine Team

fix(lite/load_and_run): fix some bugs of load and run

GitOrigin-RevId: ffd578b97b458454b434c21b78f68bc2cafc63d6
Parent 24f12df9
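The recurring fix in this diff is twofold: accessors such as `get_lite_network()` now return a reference, and call sites replace `auto x = model->get_xxx();` with `auto&& x = model->get_xxx();`. With plain `auto`, the deduced type is a value, so the object an accessor hands back by reference gets copied, and any configuration applied through that copy is silently lost. A minimal sketch of the bug and the fix (the `Model` and `Parser` types here are hypothetical stand-ins, not the repository's classes):

```cpp
#include <iostream>

struct Parser {
    int feed_count = 0;
};

struct Model {
    Parser parser;
    // Accessor hands back a mutable reference to the member.
    Parser& get_input_parser() { return parser; }
};

int main() {
    Model model;

    // Bug pattern: `auto` deduces `Parser` by value, so the returned
    // reference is copied and the mutation only touches the copy.
    auto copy = model.get_input_parser();
    copy.feed_count++;
    std::cout << model.parser.feed_count << '\n';  // prints 0

    // Fixed pattern: `auto&&` binds a reference to the member itself,
    // so the mutation is visible through the model afterwards.
    auto&& ref = model.get_input_parser();
    ref.feed_count++;
    std::cout << model.parser.feed_count << '\n';  // prints 1
}
```

The same reasoning motivates returning `std::shared_ptr<lite::Network>&` instead of a `shared_ptr` by value: a copy would still point at the same `Network`, but only a reference lets a callee observe or reseat the model's own handle, and it avoids needless refcount traffic.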
@@ -39,7 +39,7 @@ void OutputDumper::write_to_file() {
info.owner_inputs_info.c_str()));
mgb::debug::write_to_file(
mgb::ssprintf(
"%s/run%zu-var %zd", dump_file.c_str(), m_run_id, info.id)
"%s/run%zu-var%zd", dump_file.c_str(), m_run_id, info.id)
.c_str(),
value);
}
@@ -40,7 +40,7 @@ public:
void wait() override;
//! get the network of lite model
-std::shared_ptr<lite::Network> get_lite_network() { return m_network; }
+std::shared_ptr<lite::Network>& get_lite_network() { return m_network; }
//! get the config of lite model
lite::Config& get_config() { return config; }
@@ -67,13 +67,16 @@ public:
//! get data parser
DataParser& get_input_parser() { return parser; }
uint32_t get_testcase_num() { return testcase_num; }
+std::vector<std::pair<std::string, mgb::HostTensorND*>>& get_test_input() {
+return test_input_tensors;
+}
//! get output specified configuration
mgb::ComputingGraph::OutputSpec& get_output_spec() { return m_output_spec; }
std::unique_ptr<mgb::cg::AsyncExecutable>& get_async_func() { return m_asyc_exec; }
void set_output_callback(std::vector<mgb::ComputingGraph::Callback>& cb) {
@@ -84,6 +87,7 @@ public:
m_callbacks[i] = cb[i];
}
}
+#if MGB_ENABLE_JSON
std::unique_ptr<mgb::GraphProfiler>& get_profiler() { return m_profiler; }
void set_profiler() {
@@ -91,6 +95,7 @@ public:
std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
}
+#endif
void set_num_range_checker(float range) {
m_num_range_checker = std::make_unique<mgb::NumRangeChecker>(
m_load_config.comp_graph.get(), range);
@@ -37,7 +37,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
}
#endif
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-auto network = model->get_lite_network();
+auto&& network = model->get_lite_network();
if (enable_cpu_default) {
LITE_WARN("using cpu default device\n");
lite::Runtime::set_cpu_inplace_mode(network);
@@ -55,8 +55,8 @@ void FastRunOption::config_model_internel<ModelLite>(
auto lite_strategy = static_cast<Strategy>(strategy);
model->set_lite_strategy(lite_strategy);
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-auto lite_network = model->get_lite_network();
-auto lite_strategy = model->get_lite_strategy();
+auto&& lite_network = model->get_lite_network();
+auto&& lite_strategy = model->get_lite_strategy();
//! set algo policy for model
lite::Runtime::set_network_algo_policy(
lite_network, lite_strategy, share_batch_size, batch_binary_equal);
@@ -121,8 +121,8 @@ void FastRunOption::config_model_internel<ModelMdl>(
.fast_run_config.shared_batch_size = share_batch_size;
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-auto vars = model->get_mdl_load_result().output_var_list;
-auto strategy = model->get_mdl_strategy();
+auto& vars = model->get_mdl_load_result().output_var_list;
+auto&& strategy = model->get_mdl_strategy();
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
// set algo cache path
if (!m_fast_run_cache.empty()) {
@@ -20,8 +20,8 @@ template <>
void InputOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-auto parser = model->get_input_parser();
-auto io = model->get_networkIO();
+auto&& parser = model->get_input_parser();
+auto&& io = model->get_networkIO();
for (size_t idx = 0; idx < data_path.size(); ++idx) {
parser.feed(data_path[idx].c_str());
}
@@ -32,9 +32,8 @@ void InputOption::config_model_internel<ModelLite>(
io.inputs.push_back({i.first, is_host});
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-auto config = model->get_config();
-auto parser = model->get_input_parser();
-auto network = model->get_lite_network();
+auto&& parser = model->get_input_parser();
+auto&& network = model->get_lite_network();
//! data type map from mgb data type to lite data type
std::map<megdnn::DTypeEnum, LiteDataType> type_map = {
@@ -75,8 +74,8 @@ void InputOption::config_model_internel<ModelMdl>(
parser.feed(data_path[idx].c_str());
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
-auto parser = model->get_input_parser();
-auto network = model->get_mdl_load_result();
+auto&& parser = model->get_input_parser();
+auto&& network = model->get_mdl_load_result();
auto tensormap = network.tensor_map;
for (auto& i : parser.inputs) {
mgb_assert(
@@ -156,7 +155,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (enable_bin_out_dump) {
-auto load_result = model->get_mdl_load_result();
+auto&& load_result = model->get_mdl_load_result();
out_dumper->set(load_result.output_var_list);
std::vector<mgb::ComputingGraph::Callback> cb;
@@ -166,7 +165,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
model->set_output_callback(cb);
}
if (enable_copy_to_host) {
-auto load_result = model->get_mdl_load_result();
+auto&& load_result = model->get_mdl_load_result();
std::vector<mgb::ComputingGraph::Callback> cb;
for (size_t i = 0; i < load_result.output_var_list.size(); i++) {
@@ -365,7 +365,7 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>(
}
if (workspace_limit < SIZE_MAX) {
mgb_log_warn("set workspace limit to %ld", workspace_limit);
-auto output_spec = model->get_output_spec();
+auto&& output_spec = model->get_output_spec();
mgb::SymbolVarArray vars;
for (auto i : output_spec) {
vars.push_back(i.first);
@@ -46,7 +46,7 @@ template <>
void PluginOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-auto config = model->get_mdl_config();
+auto&& config = model->get_mdl_config();
if (range > 0) {
mgb_log_warn("enable number range check");
model->set_num_range_checker(float(range));
@@ -151,7 +151,7 @@ template <>
void DebugOption::format_and_print(
const std::string& tablename, std::shared_ptr<ModelLite> model) {
auto table = mgb::TextTable(tablename);
-auto network = model->get_lite_network();
+auto&& network = model->get_lite_network();
table.padding(1);
table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor();
@@ -259,7 +259,6 @@ template <>
void DebugOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
-auto config = model->get_mdl_config();
if (enable_verbose) {
mgb_log_warn("enable verbose");
mgb::set_log_level(mgb::LogLevel::DEBUG);