diff --git a/lite/load_and_run/src/options/device_options.cpp b/lite/load_and_run/src/options/device_options.cpp index b9208def802eaaa2a2218212ed4ad5e7f292e577..66be27611830f06fe953987324ee76af02f5edd5 100644 --- a/lite/load_and_run/src/options/device_options.cpp +++ b/lite/load_and_run/src/options/device_options.cpp @@ -28,7 +28,7 @@ void XPUDeviceOption::config_model_internel( model->get_config().device_type = LiteDeviceType::LITE_CUDA; } #endif - } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + } else if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) { auto&& network = model->get_lite_network(); if (enable_cpu_default) { LITE_LOG("using cpu default device\n"); @@ -86,7 +86,7 @@ void XPUDeviceOption::config_model_internel( }; } if (enable_multithread) { - mgb_log("using multithread device\n"); + mgb_log("using multithread(threads number:%ld) device\n", thread_num); model->get_mdl_config().comp_node_mapper = [&](mgb::CompNode::Locator& loc) { loc.type = mgb::CompNode::DeviceType::MULTITHREAD; @@ -217,11 +217,15 @@ void XPUDeviceOption::config_model( std::static_pointer_cast(m_option["multithread"]) ->get_value(); enable_multithread = num_of_thread >= 0; - num_of_thread = + int32_t num_of_thread_dft = std::static_pointer_cast(m_option["multithread_default"]) ->get_value(); - enable_multithread_default = num_of_thread >= 0; - thread_num = num_of_thread >= 0 ? num_of_thread : 0; + enable_multithread_default = num_of_thread_dft >= 0; + mgb_assert( + num_of_thread < 0 || num_of_thread_dft < 0, + "multithread and multithread_default should not bet set at the same time"); + thread_num = num_of_thread >= 0 ? num_of_thread + : (num_of_thread_dft >= 0 ? num_of_thread_dft : -1); std::string core_id_str = std::static_pointer_cast(m_option["multi_thread_core_ids"]) ->get_value();