提交 6814cf1c 编写于 作者: M Megvii Engine Team

fix(lite): fix lite test error

GitOrigin-RevId: ab608672ecd9a4e17a03e6e7d5c9d809b3283d50
上级 7c8f1847
......@@ -223,8 +223,10 @@ void test_multi_thread(bool multi_thread_compnode) {
std::string model_path = "./shufflenet.mge";
size_t nr_threads = 2;
std::vector<std::thread::id> thread_ids(nr_threads);
std::vector<size_t> thread_ids_user(nr_threads);
std::vector<size_t> thread_ids_worker(nr_threads);
auto runner = [&](size_t i) {
thread_ids_user[i] = std::hash<std::thread::id>{}(std::this_thread::get_id());
std::shared_ptr<Network> network = std::make_shared<Network>(config);
Runtime::set_cpu_inplace_mode(network);
if (multi_thread_compnode) {
......@@ -232,11 +234,18 @@ void test_multi_thread(bool multi_thread_compnode) {
}
network->load_model(model_path);
Runtime::set_runtime_thread_affinity(network, [&thread_ids, i](int id) {
if (id == 0) {
thread_ids[i] = std::this_thread::get_id();
}
});
Runtime::set_runtime_thread_affinity(
network, [&multi_thread_compnode, &thread_ids_worker, i](int id) {
if (multi_thread_compnode) {
if (id == 1) {
thread_ids_worker[i] = std::hash<std::thread::id>{}(
std::this_thread::get_id());
}
} else {
thread_ids_worker[i] = std::hash<std::thread::id>{}(
std::this_thread::get_id());
}
});
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
auto src_ptr = lite_tensor->get_memory_ptr();
......@@ -250,11 +259,11 @@ void test_multi_thread(bool multi_thread_compnode) {
std::vector<std::thread> threads;
for (size_t i = 0; i < nr_threads; i++) {
threads.emplace_back(runner, i);
threads[i].join();
}
for (size_t i = 0; i < nr_threads; i++) {
threads[i].join();
ASSERT_EQ(thread_ids_user[i], thread_ids_worker[i]);
}
ASSERT_NE(thread_ids[0], thread_ids[1]);
}
} // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册