提交 d4bf57d6 编写于 作者: M Megvii Engine Team

fix(lite): fix lar multithread options invalid

GitOrigin-RevId: 55f83036b1b2ae4124385d8e3d839e6b95bfd809
上级 1404437a
...@@ -28,7 +28,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>( ...@@ -28,7 +28,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
model->get_config().device_type = LiteDeviceType::LITE_CUDA; model->get_config().device_type = LiteDeviceType::LITE_CUDA;
} }
#endif #endif
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { } else if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
auto&& network = model->get_lite_network(); auto&& network = model->get_lite_network();
if (enable_cpu_default) { if (enable_cpu_default) {
LITE_LOG("using cpu default device\n"); LITE_LOG("using cpu default device\n");
...@@ -86,7 +86,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>( ...@@ -86,7 +86,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
}; };
} }
if (enable_multithread) { if (enable_multithread) {
mgb_log("using multithread device\n"); mgb_log("using multithread(threads number:%ld) device\n", thread_num);
model->get_mdl_config().comp_node_mapper = model->get_mdl_config().comp_node_mapper =
[&](mgb::CompNode::Locator& loc) { [&](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::MULTITHREAD; loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
...@@ -217,11 +217,15 @@ void XPUDeviceOption::config_model( ...@@ -217,11 +217,15 @@ void XPUDeviceOption::config_model(
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"]) std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"])
->get_value(); ->get_value();
enable_multithread = num_of_thread >= 0; enable_multithread = num_of_thread >= 0;
num_of_thread = int32_t num_of_thread_dft =
std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"]) std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"])
->get_value(); ->get_value();
enable_multithread_default = num_of_thread >= 0; enable_multithread_default = num_of_thread_dft >= 0;
thread_num = num_of_thread >= 0 ? num_of_thread : 0; mgb_assert(
num_of_thread < 0 || num_of_thread_dft < 0,
"multithread and multithread_default should not bet set at the same time");
thread_num = num_of_thread >= 0 ? num_of_thread
: (num_of_thread_dft >= 0 ? num_of_thread_dft : -1);
std::string core_id_str = std::string core_id_str =
std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"]) std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"])
->get_value(); ->get_value();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册