提交 9488cd1c 编写于 作者: M Megvii Engine Team 提交者: dengzheye

fix(lite): fix lite cpu default not work

GitOrigin-RevId: 8fc764623cacf3994be09a343c7560cbd933d15c
上级 518c7f37
...@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) { ...@@ -67,10 +67,15 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
void NetworkImplDft::application_config() { void NetworkImplDft::application_config() {
auto device_type = m_user_config->device_type; auto device_type = m_user_config->device_type;
m_compnode_locator.type = to_compnode_locator(device_type).type; m_compnode_locator.type = to_compnode_locator(device_type).type;
m_compnode_locator.device = m_user_config->device_id; //! when the device id is not configured, configure it
if (m_compnode_locator.device == -1) {
m_compnode_locator.device = m_user_config->device_id;
}
if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) { if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD; m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
m_compnode_locator.device = m_user_config->device_id; if (m_compnode_locator.device == -1) {
m_compnode_locator.device = m_user_config->device_id;
}
} }
//! model options //! model options
#define ConfigOption(mge_name, lite_name) \ #define ConfigOption(mge_name, lite_name) \
...@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() { ...@@ -155,11 +160,13 @@ void NetworkImplDft::set_cpu_inplace_mode() {
m_is_cpu_inplace_mode = true; m_is_cpu_inplace_mode = true;
if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) { if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
} else { } else {
LITE_ASSERT( LITE_ASSERT(
m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD, m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
"cpu inplace mode is only avaliable in CPU."); "cpu inplace mode is only avaliable in CPU.");
m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
} }
} }
...@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) { ...@@ -170,6 +177,12 @@ void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
if (nr_threads > 1) { if (nr_threads > 1) {
m_nr_threads = nr_threads; m_nr_threads = nr_threads;
m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD; m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
if (m_is_cpu_inplace_mode) {
m_compnode_locator.device =
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
m_user_config->device_id =
mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
}
m_compnode_locator.nr_threads = nr_threads; m_compnode_locator.nr_threads = nr_threads;
} }
} }
......
...@@ -216,6 +216,57 @@ TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) { ...@@ -216,6 +216,57 @@ TEST(TestNetWork, BasicInplaceAndSingleThreadAffinity) {
compare_lite_tensor<float>(output_tensor, result_mgb); compare_lite_tensor<float>(output_tensor, result_mgb);
} }
namespace {
void test_multi_thread(bool multi_thread_compnode) {
Config config;
auto lite_tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
size_t nr_threads = 2;
std::vector<std::thread::id> thread_ids(nr_threads);
auto runner = [&](size_t i) {
std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_inplace_mode(network);
if (multi_thread_compnode) {
Runtime::set_cpu_threads_number(network, 2);
}
network->load_model(model_path);
Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
if (id == 0) {
thread_ids[i] = std::this_thread::get_id();
}
});
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
auto src_ptr = lite_tensor->get_memory_ptr();
auto src_layout = lite_tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
network->forward();
network->wait();
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nr_threads; i++) {
threads.emplace_back(runner, i);
}
for (size_t i = 0; i < nr_threads; i++) {
threads[i].join();
}
ASSERT_NE(thread_ids[0], thread_ids[1]);
}
} // namespace
TEST(TestNetWork, InplaceAndUserMultithreadThread) {
test_multi_thread(false);
}
TEST(TestNetWork, InplaceAndMultithread) {
test_multi_thread(true);
}
TEST(TestNetWork, NetworkShareWeights) { TEST(TestNetWork, NetworkShareWeights) {
Config config; Config config;
auto lite_tensor = get_input_data("./input_data.npy"); auto lite_tensor = get_input_data("./input_data.npy");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册