diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index 3b7d0ef68e053bd0daea1c01abe9932b8279f110..0df72d400089550457d26cfc6c8be131ffdc3e93 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
 void NetworkImplDft::application_config() {
     auto device_type = m_user_config->device_type;
     m_compnode_locator.type = to_compnode_locator(device_type).type;
-    m_compnode_locator.device = m_user_config->device_id;
+    //! when the device id is not configured, configure it
+    if (m_compnode_locator.device == -1) {
+        m_compnode_locator.device = m_user_config->device_id;
+    }
     if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
         m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
-        m_compnode_locator.device = m_user_config->device_id;
+        if (m_compnode_locator.device == -1) {
+            m_compnode_locator.device = m_user_config->device_id;
+        }
     }
     //! model options
 #define ConfigOption(mge_name, lite_name) \
@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() {
     m_is_cpu_inplace_mode = true;
     if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
         m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
+        m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
     } else {
         LITE_ASSERT(
                 m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                 "cpu inplace mode is only avaliable in CPU.");
         m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
+        m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
     }
 }
 
@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
     if (nr_threads > 1) {
         m_nr_threads = nr_threads;
         m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
+        if (m_is_cpu_inplace_mode) {
+            m_compnode_locator.device =
+                    mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
+            m_user_config->device_id =
+                    mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
+        }
         m_compnode_locator.nr_threads = nr_threads;
     }
 }
diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index 707fc19d4711ce2c65628a50edfcbf9fdaf89500..8b57e45c56dd51c19bd1ccfe2c6deb9f749f2534 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -216,6 +216,57 @@ TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
     compare_lite_tensor(output_tensor, result_mgb);
 }
 
+namespace {
+void test_multi_thread(bool multi_thread_compnode) {
+    Config config;
+    auto lite_tensor = get_input_data("./input_data.npy");
+    std::string model_path = "./shufflenet.mge";
+
+    size_t nr_threads = 2;
+    std::vector<std::thread::id> thread_ids(nr_threads);
+    auto runner = [&](size_t i) {
+        std::shared_ptr<Network> network = std::make_shared<Network>(config);
+        Runtime::set_cpu_inplace_mode(network);
+        if (multi_thread_compnode) {
+            Runtime::set_cpu_threads_number(network, 2);
+        }
+
+        network->load_model(model_path);
+        Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
+            if (id == 0) {
+                thread_ids[i] = std::this_thread::get_id();
+            }
+        });
+
+        std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
+        auto src_ptr = lite_tensor->get_memory_ptr();
+        auto src_layout = lite_tensor->get_layout();
+        input_tensor->reset(src_ptr, src_layout);
+
+        network->forward();
+        network->wait();
+        std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+    };
+    std::vector<std::thread> threads;
+    for (size_t i = 0; i < nr_threads; i++) {
+        threads.emplace_back(runner, i);
+    }
+    for (size_t i = 0; i < nr_threads; i++) {
+        threads[i].join();
+    }
+    ASSERT_NE(thread_ids[0], thread_ids[1]);
+}
+
+}  // namespace
+
+TEST(TestNetWork, InplaceAndUserMultithreadThread) {
+    test_multi_thread(false);
+}
+
+TEST(TestNetWork, InplaceAndMultithread) {
+    test_multi_thread(true);
+}
+
 TEST(TestNetWork, NetworkShareWeights) {
     Config config;
     auto lite_tensor = get_input_data("./input_data.npy");