From d1d8ddeeac04d61307a657be933fb575ac844e48 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 18 Aug 2022 15:02:31 +0800 Subject: [PATCH] fix(lite): fix lar multithread options invalid GitOrigin-RevId: 55f83036b1b2ae4124385d8e3d839e6b95bfd809 --- lite/load_and_run/src/options/device_options.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lite/load_and_run/src/options/device_options.cpp b/lite/load_and_run/src/options/device_options.cpp index b9208def8..66be27611 100644 --- a/lite/load_and_run/src/options/device_options.cpp +++ b/lite/load_and_run/src/options/device_options.cpp @@ -28,7 +28,7 @@ void XPUDeviceOption::config_model_internel( model->get_config().device_type = LiteDeviceType::LITE_CUDA; } #endif - } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + } else if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) { auto&& network = model->get_lite_network(); if (enable_cpu_default) { LITE_LOG("using cpu default device\n"); @@ -86,7 +86,7 @@ void XPUDeviceOption::config_model_internel( }; } if (enable_multithread) { - mgb_log("using multithread device\n"); + mgb_log("using multithread(threads number:%ld) device\n", thread_num); model->get_mdl_config().comp_node_mapper = [&](mgb::CompNode::Locator& loc) { loc.type = mgb::CompNode::DeviceType::MULTITHREAD; @@ -217,11 +217,15 @@ void XPUDeviceOption::config_model( std::static_pointer_cast(m_option["multithread"]) ->get_value(); enable_multithread = num_of_thread >= 0; - num_of_thread = + int32_t num_of_thread_dft = std::static_pointer_cast(m_option["multithread_default"]) ->get_value(); - enable_multithread_default = num_of_thread >= 0; - thread_num = num_of_thread >= 0 ? num_of_thread : 0; + enable_multithread_default = num_of_thread_dft >= 0; + mgb_assert( + num_of_thread < 0 || num_of_thread_dft < 0, + "multithread and multithread_default should not bet set at the same time"); + thread_num = num_of_thread >= 0 ? num_of_thread + : (num_of_thread_dft >= 0 ? num_of_thread_dft : -1); std::string core_id_str = std::static_pointer_cast(m_option["multi_thread_core_ids"]) ->get_value(); -- GitLab