未验证 提交 f581f5bf 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] fix bug that no thread is waked up when adding task to threadpool (#41567)

* fix bug that no thread is waked up when adding task to threadpool

* fix typo
上级 b3e79731
......@@ -39,6 +39,7 @@ constexpr size_t kPrepareWorkQueueIdx = 2;
void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
std::function<void()> fn) {
VLOG(4) << "Add task: " << static_cast<size_t>(op_func_type) << " ";
// NOTE(zhiqiu): use thhe second queue of size of, so only one thread is used.
if (FLAGS_new_executor_sequential_run) {
VLOG(4) << "FLAGS_new_executor_sequential_run:"
......
......@@ -54,6 +54,7 @@
#include <cstdlib>
#include <mutex>
#include <vector>
#include "glog/logging.h"
namespace paddle {
namespace framework {
......@@ -255,6 +256,7 @@ class EventCount {
std::unique_lock<std::mutex> lock(w->mu);
while (w->state != Waiter::kSignaled) {
w->state = Waiter::kWaiting;
VLOG(10) << "Go to wait " << &(w->cv);
w->cv.wait(lock);
}
}
......@@ -270,7 +272,10 @@ class EventCount {
w->state = Waiter::kSignaled;
}
// Avoid notifying if it wasn't waiting.
if (state == Waiter::kWaiting) w->cv.notify_one();
if (state == Waiter::kWaiting) {
VLOG(10) << "Go to notify " << &(w->cv);
w->cv.notify_one();
}
}
}
};
......
......@@ -53,7 +53,6 @@ class ThreadPoolTempl {
all_coprimes_.reserve(num_threads_);
for (int i = 1; i <= num_threads_; ++i) {
all_coprimes_.emplace_back();
all_coprimes_.back().push_back(i);
ComputeCoprimes(i, &(all_coprimes_.back()));
}
for (int i = 0; i < num_threads_; i++) {
......@@ -130,8 +129,11 @@ class ThreadPoolTempl {
// this. We expect that such scenario is prevented by program, that is,
// this is kept alive while any threads can potentially be in Schedule.
if (!t.f) {
if (num_tasks > num_threads_ - blocked_.load(std::memory_order_relaxed)) {
if (num_tasks > num_threads_ - blocked_) {
VLOG(6) << "Add task, Notify";
ec_.Notify(false);
} else {
VLOG(6) << "Add task, No Notify";
}
} else {
num_tasks_.fetch_sub(1, std::memory_order_relaxed);
......@@ -376,17 +378,21 @@ class ThreadPoolTempl {
ec_.CancelWait();
return false;
}
// Number of blocked threads is used as termination condition.
// If we are shutting down and all worker threads blocked without work,
// that's we are done.
blocked_++;
// Now do a reliable emptiness check.
int victim = NonEmptyQueueIndex();
if (victim != -1) {
ec_.CancelWait();
*t = thread_data_[victim].queue.PopBack();
blocked_--;
return true;
}
// Number of blocked threads is used as termination condition.
// If we are shutting down and all worker threads blocked without work,
// that's we are done.
blocked_++;
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
ec_.CancelWait();
// Almost done, but need to re-check queues.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册