未验证 提交 f581f5bf 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] fix bug that no thread is waked up when adding task to threadpool (#41567)

* fix bug that no thread is waked up when adding task to threadpool

* fix typo
上级 b3e79731
...@@ -39,6 +39,7 @@ constexpr size_t kPrepareWorkQueueIdx = 2; ...@@ -39,6 +39,7 @@ constexpr size_t kPrepareWorkQueueIdx = 2;
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
std::function<void()> fn) { std::function<void()> fn) {
VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
// NOTE(zhiqiu): use the second queue of size one, so only one thread is used. // NOTE(zhiqiu): use the second queue of size one, so only one thread is used.
if (FLAGS_new_executor_sequential_run) { if (FLAGS_new_executor_sequential_run) {
VLOG(4) << "FLAGS_new_executor_sequential_run:" VLOG(4) << "FLAGS_new_executor_sequential_run:"
......
...@@ -54,6 +54,7 @@ ...@@ -54,6 +54,7 @@
#include <cstdlib> #include <cstdlib>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -255,6 +256,7 @@ class EventCount { ...@@ -255,6 +256,7 @@ class EventCount {
std::unique_lock<std::mutex> lock(w->mu); std::unique_lock<std::mutex> lock(w->mu);
while (w->state != Waiter::kSignaled) { while (w->state != Waiter::kSignaled) {
w->state = Waiter::kWaiting; w->state = Waiter::kWaiting;
VLOG(10) << "Go to wait " << &(w->cv);
w->cv.wait(lock); w->cv.wait(lock);
} }
} }
...@@ -270,7 +272,10 @@ class EventCount { ...@@ -270,7 +272,10 @@ class EventCount {
w->state = Waiter::kSignaled; w->state = Waiter::kSignaled;
} }
// Avoid notifying if it wasn't waiting. // Avoid notifying if it wasn't waiting.
if (state == Waiter::kWaiting) w->cv.notify_one(); if (state == Waiter::kWaiting) {
VLOG(10) << "Go to notify " << &(w->cv);
w->cv.notify_one();
}
} }
} }
}; };
......
...@@ -53,7 +53,6 @@ class ThreadPoolTempl { ...@@ -53,7 +53,6 @@ class ThreadPoolTempl {
all_coprimes_.reserve(num_threads_); all_coprimes_.reserve(num_threads_);
for (int i = 1; i <= num_threads_; ++i) { for (int i = 1; i <= num_threads_; ++i) {
all_coprimes_.emplace_back(); all_coprimes_.emplace_back();
all_coprimes_.back().push_back(i);
ComputeCoprimes(i, &(all_coprimes_.back())); ComputeCoprimes(i, &(all_coprimes_.back()));
} }
for (int i = 0; i < num_threads_; i++) { for (int i = 0; i < num_threads_; i++) {
...@@ -130,8 +129,11 @@ class ThreadPoolTempl { ...@@ -130,8 +129,11 @@ class ThreadPoolTempl {
// this. We expect that such scenario is prevented by program, that is, // this. We expect that such scenario is prevented by program, that is,
// this is kept alive while any threads can potentially be in Schedule. // this is kept alive while any threads can potentially be in Schedule.
if (!t.f) { if (!t.f) {
if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) { if (num_tasks > num_threads_ - blocked_) {
VLOG(6) << "Add task, Notify";
ec_.Notify(false); ec_.Notify(false);
} else {
VLOG(6) << "Add task, No Notify";
} }
} else { } else {
num_tasks_.fetch_sub(1, std::memory_order_relaxed); num_tasks_.fetch_sub(1, std::memory_order_relaxed);
...@@ -376,17 +378,21 @@ class ThreadPoolTempl { ...@@ -376,17 +378,21 @@ class ThreadPoolTempl {
ec_.CancelWait(); ec_.CancelWait();
return false; return false;
} }
// Number of blocked threads is used as termination condition.
// If we are shutting down and all worker threads are blocked without work,
// then we are done.
blocked_++;
// Now do a reliable emptiness check. // Now do a reliable emptiness check.
int victim = NonEmptyQueueIndex(); int victim = NonEmptyQueueIndex();
if (victim != -1) { if (victim != -1) {
ec_.CancelWait(); ec_.CancelWait();
*t = thread_data_[victim].queue.PopBack(); *t = thread_data_[victim].queue.PopBack();
blocked_--;
return true; return true;
} }
// Number of blocked threads is used as termination condition.
// If we are shutting down and all worker threads are blocked without work,
// then we are done.
blocked_++;
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) { if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
ec_.CancelWait(); ec_.CancelWait();
// Almost done, but need to re-check queues. // Almost done, but need to re-check queues.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册