Commit 09b5f3d4 authored by Megvii Engine Team, committed by Xinran Xu

fix(mgb/core): fix multi thread pool deactive and multi thread conflict

GitOrigin-RevId: 36787a08a5aa8a2e6360f4bd9992039262974797
Parent ef239f83
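
For context on the commit title: roughly, ThreadPool::active() lets the pool's worker threads run (they then poll for sub-tasks) and ThreadPool::deactive() parks them, so an idle CPU comp node stops spinning. The snippet below is only a minimal sketch of that idea, with invented names (PoolStateSketch); it is not MegEngine's actual ThreadPool implementation.

#include <condition_variable>
#include <mutex>

// Minimal model of the active/deactive pattern this commit is about:
// while "active", workers are awake and poll for work; deactive() lets
// them block on a condition variable instead of burning CPU.
class PoolStateSketch {
public:
    void active() {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_active = true;
        m_cv.notify_all();                // wake sleeping workers
    }
    void deactive() {
        std::lock_guard<std::mutex> lock(m_mutex);
        m_active = false;                 // workers sleep once they run dry
    }
    void worker_wait_until_active() {     // called by a worker with no work left
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cv.wait(lock, [this] { return m_active; });
    }

private:
    std::mutex m_mutex;
    std::condition_variable m_cv;
    bool m_active = false;
};

The diff below does two things around this pattern: it makes sure deactive() is actually reached once all dispatched work has finished (comp_node changes), and it makes deactive() and the destructor take the same locks, in the same order, as add_task() (thread_pool changes).
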
@@ -63,13 +63,12 @@ class CpuCompNode::WorkerQueue final
#endif
}
sys::set_thread_name(m_locator.to_string());
if(m_thread_pool)
m_thread_pool->active();
}
void on_sync_all_task_finish() override {
if (m_thread_pool)
if (m_thread_pool) {
m_thread_pool->deactive();
}
}
public:
@@ -436,6 +435,8 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
}
}
ThreadPool* get_thread_pool() const { return m_thread_pool.get(); }
void* mgb_aligned_alloc(size_t size) {
auto alignment = get_mem_addr_alignment();
#ifdef WIN32
@@ -546,6 +547,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
} else if (m_worker_queue) {
m_worker_queue->wait_all_task_finish();
}
if (m_thread_pool) {
m_thread_pool->deactive();
}
}
void dispatch(Task &&task) override {
@@ -893,6 +897,11 @@ bool CpuCompNode::CpuDispatchableBase::EventImpl::do_finished() {
void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
for (size_t i = 0, it = SCQueueSynchronizer::max_spin() / 20; i < it; ++i) {
if (finished()) {
auto thread_pool = static_cast<CpuCompNodeImpl*>(m_comp_node_impl)
->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
return;
}
}
@@ -906,6 +915,11 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::host_wait_cv() {
m_dev_wait_cv.wait(lock);
}
m_dev_wait_nr_waiter.fetch_sub(1, std::memory_order_release);
auto thread_pool =
static_cast<CpuCompNodeImpl*>(m_comp_node_impl)->get_thread_pool();
if (thread_pool) {
thread_pool->deactive();
}
}
CpuCompNode::CpuDispatchableBase::EventImpl::~EventImpl() noexcept {
......
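
The comp_node hunks above add the same post-wait step in several places (after all queued tasks finish, after sync(), and at the end of the event host wait): once the waiter knows the dispatched work is done, the comp node's thread pool is put back to sleep. A hedged sketch of that step; the helper name below is invented and CompNodeLike merely stands in for CpuCompNodeImpl from the diff:

// Sketch only: after a wait/sync completes, deactivate the pool (if the
// comp node has one) so its workers stop spinning until the next dispatch.
template <typename CompNodeLike>
void deactive_pool_after_wait_sketch(CompNodeLike* impl) {
    if (auto* pool = impl->get_thread_pool()) {  // may be null, hence the check (as in the diff)
        pool->deactive();
    }
}
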
@@ -74,6 +74,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
//! Make sure the main thread has been bound
if (m_main_affinity_flag &&
m_core_binding_function != nullptr) {
std::lock_guard<std::mutex> lock(m_mutex_task);
m_core_binding_function(m_nr_threads - 1);
m_main_affinity_flag = false;
}
@@ -85,10 +86,10 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
}
return;
} else {
std::lock_guard<std::mutex> lock(m_mutex_task);
mgb_assert(m_task_iter.load(std::memory_order_acquire) <= 0,
"The init value of m_all_sub_task is not zero.");
active();
std::lock_guard<std::mutex> lock(m_mutex_task);
//! Set the task number, task iter and task
m_nr_parallelism = parallelism;
m_task_iter.exchange(parallelism, std::memory_order_relaxed);
@@ -113,6 +114,7 @@ void ThreadPool::add_task(const TaskElem& task_elem) {
void ThreadPool::set_affinity(AffinityCallBack affinity_cb) {
mgb_assert(affinity_cb, "The affinity callback must not be nullptr");
std::lock_guard<std::mutex> lock(m_mutex_task);
m_core_binding_function = affinity_cb;
for (size_t i = 0; i < m_nr_threads - 1; i++) {
m_workers[i]->affinity_flag = true;
@@ -147,10 +149,12 @@ void ThreadPool::active() {
}
}
void ThreadPool::deactive() {
std::lock_guard<std::mutex> lock_task(m_mutex_task);
std::unique_lock<std::mutex> lock(m_mutex);
m_active = false;
}
ThreadPool::~ThreadPool() {
std::lock_guard<std::mutex> lock_task(m_mutex_task);
{
std::unique_lock<std::mutex> lock(m_mutex);
m_stop = true;
......
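
The thread_pool hunks above share one rule: paths that change pool-wide state (binding affinity, deactive(), the destructor) now take m_mutex_task before m_mutex, the same order as add_task(). Taking m_mutex_task serializes deactivation against task submission, and keeping a single acquisition order avoids introducing a lock-order deadlock. The struct below is a simplified, hypothetical illustration of that ordering, not the real ThreadPool:

#include <mutex>

struct LockOrderSketch {
    std::mutex m_mutex_task;  // guards task submission and pool-wide changes
    std::mutex m_mutex;       // guards the worker wake/sleep flag
    bool m_active = false;

    void add_task_like() {
        std::lock_guard<std::mutex> lock_task(m_mutex_task);  // first lock
        // ... publish sub-tasks ...
        std::lock_guard<std::mutex> lock(m_mutex);            // second lock
        m_active = true;                                      // wake workers
    }

    void deactive_like() {
        std::lock_guard<std::mutex> lock_task(m_mutex_task);  // same order as add_task_like()
        std::lock_guard<std::mutex> lock(m_mutex);
        m_active = false;
    }
};
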
@@ -80,7 +80,7 @@ public:
~ThreadPool();
private:
size_t m_nr_threads = 0;
const size_t m_nr_threads = 0;
//! Indicate whether the main thread has been bound
bool m_main_affinity_flag;
//! The callback binding the threads to cores
......
@@ -12,6 +12,8 @@
#include "megbrain/comp_node.h"
#include "megbrain/system.h"
#include "megbrain/test/helper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include <atomic>
#include <random>
@@ -59,6 +61,73 @@ TEST(TestThreadPool, BASIC) {
ASSERT_EQ(dst1[i], truth[i]);
}
}
TEST(TestGraph, ParallelRunMultithreadMode) {
// check race conditions when graphs are executed on multiple threads
std::atomic_size_t sync_counter{0};
constexpr size_t NR_RUN = 50;
size_t nr_worker = std::max<size_t>(4, sys::get_cpu_count() / 4);
if (auto setting = MGB_GETENV("TestGraphParallelRun_nr_worker")) {
nr_worker = std::stoul(setting);
}
mgb_log("use %zu workers", nr_worker);
auto sync_barrier = [&sync_counter, nr_worker](size_t& cnt) {
++sync_counter;
++cnt;
while (sync_counter < cnt * nr_worker)
;
};
auto do_worker = [&sync_barrier](size_t sync_cnt) {
auto cn = CompNode::load("multithread2:0");
HostTensorGenerator<> gen;
auto host_x = gen({23}, cn);
HostTensorND host_y, y_expect;
y_expect.copy_from(*host_x);
{
auto py = y_expect.ptr<float>();
for (int i = 0; i < 23; ++i) {
for (int j = 0; j < 5; ++j) {
py[i] = py[i] * 2 + 3;
}
}
}
sync_barrier(sync_cnt);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x), y = x;
for (int i = 0; i < 5; ++i) {
y = y * 2 + 3;
}
sync_barrier(sync_cnt);
auto func = graph->compile({make_callback_copy(y, host_y)});
sync_barrier(sync_cnt);
func->execute();
MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
memset(host_y.raw_ptr(), -1, 23 * sizeof(float));
sync_barrier(sync_cnt);
func->execute();
MGB_ASSERT_TENSOR_EQ(y_expect, host_y);
func->wait();
};
auto worker = [&]() {
size_t scnt = 0;
for (size_t run_id = 0; run_id < NR_RUN; ++run_id) {
do_worker(scnt);
}
};
std::vector<std::thread> workers;
for (size_t i = 0; i < nr_worker; ++i)
workers.emplace_back(worker);
for (auto&& i : workers)
i.join();
}
#else
#pragma message "tests are disabled as thread is not enabled."
#endif // MGB_HAVE_THREAD
......