提交 09b5f3d4 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mgb/core): fix multi thread pool deactive and multi thread conflict

GitOrigin-RevId: 36787a08a5aa8a2e6360f4bd9992039262974797
上级 ef239f83
...@@ -63,13 +63,12 @@ class CpuCompNode::WorkerQueue final ...@@ -63,13 +63,12 @@ class CpuCompNode::WorkerQueue final
#endif #endif
} }
sys::set_thread_name(m_locator.to_string()); sys::set_thread_name(m_locator.to_string());
if(m_thread_pool)
m_thread_pool->active();
} }
void on_sync_all_task_finish() override { void on_sync_all_task_finish() override {
if (m_thread_pool) if (m_thread_pool) {
m_thread_pool->deactive(); m_thread_pool->deactive();
}
} }
public: public:
...@@ -436,6 +435,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { ...@@ -436,6 +435,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
} }
} }
ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
void* mgb_aligned_alloc(size_t size) { void* mgb_aligned_alloc(size_t size) {
auto alignment = get_mem_addr_alignment(); auto alignment = get_mem_addr_alignment();
#ifdef WIN32 #ifdef WIN32
...@@ -546,6 +547,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { ...@@ -546,6 +547,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
} else if (m_worker_queue) { } else if (m_worker_queue) {
m_worker_queue->wait_all_task_finish(); m_worker_queue->wait_all_task_finish();
} }
if (m_thread_pool) {
m_thread_pool->deactive();
}
} }
void dispatch(Task &&task) override { void dispatch(Task &&task) override {
...@@ -893,6 +897,11 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() { ...@@ -893,6 +897,11 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() { void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
for (size_t i = 0, it = SCQueueSynchronizer::max_spin() / 20; i < it; ++i) { for (size_t i = 0, it = SCQueueSynchronizer::max_spin() / 20; i < it; ++i) {
if (finished()) { if (finished()) {
auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
return; return;
} }
} }
...@@ -906,6 +915,11 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() { ...@@ -906,6 +915,11 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
m_dev_wait_cv.wait(lock); m_dev_wait_cv.wait(lock);
} }
m_dev_wait_nr_waiter.fetch_sub(1, std::memory_order_release); m_dev_wait_nr_waiter.fetch_sub(1, std::memory_order_release);
auto thread_pool =
static_cast<CpuCompNodeImpl*>(m_comp_node_impl)->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
} }
CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept { CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept {
......
...@@ -74,6 +74,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) { ...@@ -74,6 +74,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
//! Make sure the main thread have bind //! Make sure the main thread have bind
if (m_main_affinity_flag && if (m_main_affinity_flag &&
m_core_binding_function != nullptr) { m_core_binding_function != nullptr) {
std::lock_guard<std::mutex> lock(m_mutex_task);
m_core_binding_function(m_nr_threads - 1); m_core_binding_function(m_nr_threads - 1);
m_main_affinity_flag = false; m_main_affinity_flag = false;
} }
...@@ -85,10 +86,10 @@ void ThreadPool::add_task(const TaskElem& task_elem) { ...@@ -85,10 +86,10 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
} }
return; return;
} else { } else {
std::lock_guard<std::mutex> lock(m_mutex_task);
mgb_assert(m_task_iter.load(std::memory_order_acquire) <= 0, mgb_assert(m_task_iter.load(std::memory_order_acquire) <= 0,
"The init value of m_all_sub_task is not zero."); "The init value of m_all_sub_task is not zero.");
active(); active();
std::lock_guard<std::mutex> lock(m_mutex_task);
//! Set the task number, task iter and task //! Set the task number, task iter and task
m_nr_parallelism = parallelism; m_nr_parallelism = parallelism;
m_task_iter.exchange(parallelism, std::memory_order_relaxed); m_task_iter.exchange(parallelism, std::memory_order_relaxed);
...@@ -113,6 +114,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) { ...@@ -113,6 +114,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
void ThreadPool::set_affinity(AffinityCallBack affinity_cb) { void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
mgb_assert(affinity_cb, "The affinity callback must not be nullptr"); mgb_assert(affinity_cb, "The affinity callback must not be nullptr");
std::lock_guard<std::mutex> lock(m_mutex_task);
m_core_binding_function = affinity_cb; m_core_binding_function = affinity_cb;
for (size_t i = 0; i < m_nr_threads - 1; i++) { for (size_t i = 0; i < m_nr_threads - 1; i++) {
m_workers[i]->affinity_flag = true; m_workers[i]->affinity_flag = true;
...@@ -147,10 +149,12 @@ void ThreadPool::active() { ...@@ -147,10 +149,12 @@ void ThreadPool::active() {
} }
} }
void ThreadPool::deactive() { void ThreadPool::deactive() {
std::lock_guard<std::mutex> lock_task(m_mutex_task);
std::unique_lock<std::mutex> lock(m_mutex); std::unique_lock<std::mutex> lock(m_mutex);
m_active = false; m_active = false;
} }
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
std::lock_guard<std::mutex> lock_task(m_mutex_task);
{ {
std::unique_lock<std::mutex> lock(m_mutex); std::unique_lock<std::mutex> lock(m_mutex);
m_stop = true; m_stop = true;
......
...@@ -80,7 +80,7 @@ public: ...@@ -80,7 +80,7 @@ public:
~ThreadPool(); ~ThreadPool();
private: private:
size_t m_nr_threads = 0; const size_t m_nr_threads = 0;
//! Indicate whether the main thread have binding //! Indicate whether the main thread have binding
bool m_main_affinity_flag; bool m_main_affinity_flag;
//! The callback binding the threads to cores //! The callback binding the threads to cores
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/system.h" #include "megbrain/system.h"
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include <atomic> #include <atomic>
#include <random> #include <random>
...@@ -59,6 +61,73 @@ TEST(TestThreadPool, BASIC) { ...@@ -59,6 +61,73 @@ TEST(TestThreadPool, BASIC) {
ASSERT_EQ(dst1[i], truth[i]); ASSERT_EQ(dst1[i], truth[i]);
} }
} }
TEST(TestGraph, ParallelRunMultithreadMode) {
// check race conditions when graphs are executed on multple threads
std::atomic_size_t sync_counter{0};
constexpr size_t NR_RUN = 50;
size_t nr_worker = std::max(4, sys::get_cpu_count() / 4);
if (auto setting = MGB_GETENV("TestGraphParallelRun_nr_worker")) {
nr_worker = std::stoul(setting);
}
mgb_log("use %zu workers", nr_worker);
auto sync_barrier = [&sync_counter, nr_worker](size_t& cnt) {
++sync_counter;
++cnt;
while (sync_counter < cnt * nr_worker)
;
};
auto do_worker = [&sync_barrier](size_t sync_cnt) {
auto cn = CompNode::load("multithread2:0");
HostTensorGenerator<> gen;
auto host_x = gen({23}, cn);
HostTensorND host_y, y_expect;
y_expect.copy_from(*host_x);
{
auto py = y_expect.ptr<float>();
for (int i = 0; i < 23; ++i) {
for (int j = 0; j < 5; ++j) {
py[i] = py[i] * 2 + 3;
}
}
}
sync_barrier(sync_cnt);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x), y = x;
for (int i = 0; i < 5; ++i) {
y = y * 2 + 3;
}
sync_barrier(sync_cnt);
auto func = graph->compile({make_callback_copy(y, host_y)});
sync_barrier(sync_cnt);
func->execute();
MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
memset(host_y.raw_ptr(), -1, 23 * sizeof(float));
sync_barrier(sync_cnt);
func->execute();
MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
func->wait();
};
auto worker = [&]() {
size_t scnt = 0;
for (size_t run_id = 0; run_id < NR_RUN; ++run_id) {
do_worker(scnt);
}
};
std::vector<std::thread> workers;
for (size_t i = 0; i < nr_worker; ++i)
workers.emplace_back(worker);
for (auto&& i : workers)
i.join();
}
#else #else
#pragma message "tests are disabled as thread is not enabled." #pragma message "tests are disabled as thread is not enabled."
#endif // MGB_HAVE_THREAD #endif // MGB_HAVE_THREAD
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册