未验证 提交 9db6c762 编写于 作者: L liutiexing 提交者: GitHub

WorkQueue supports always_spinning option (#42029)

* WorkQueue supports always_spinning option

* update

* update
上级 920d44df
......@@ -63,6 +63,7 @@ class AsyncWorkQueue {
group_options.emplace_back(/*name*/ "HostTasks",
/*num_threads*/ host_num_threads,
/*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
......@@ -70,6 +71,7 @@ class AsyncWorkQueue {
group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
/*num_threads*/ deivce_num_threads,
/*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
......@@ -77,6 +79,7 @@ class AsyncWorkQueue {
group_options.emplace_back(/*name*/ "Prepare",
/*num_threads*/ 1,
/*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false,
/*detached*/ true,
/*events_waiter*/ waiter);
......
......@@ -29,13 +29,13 @@ class ThreadPoolTempl {
typedef RunQueue<Task, 1024> Queue;
ThreadPoolTempl(const std::string& name, int num_threads, bool allow_spinning,
Environment env = Environment())
bool always_spinning, Environment env = Environment())
: env_(env),
allow_spinning_(allow_spinning),
always_spinning_(always_spinning),
global_steal_partition_(EncodePartition(0, num_threads_)),
blocked_(0),
num_tasks_(0),
spinning_(0),
done_(false),
cancelled_(false),
ec_(num_threads),
......@@ -236,11 +236,11 @@ class ThreadPoolTempl {
Environment env_;
const bool allow_spinning_;
const bool always_spinning_;
std::vector<std::vector<unsigned>> all_coprimes_;
unsigned global_steal_partition_;
std::atomic<unsigned> blocked_;
std::atomic<uint64_t> num_tasks_;
std::atomic<bool> spinning_;
std::atomic<bool> done_;
std::atomic<bool> cancelled_;
EventCount ec_;
......@@ -417,6 +417,15 @@ class ThreadPoolTempl {
ec_.Notify(true);
return false;
}
// Cancel wait if always_spinning_
if (always_spinning_) {
ec_.CancelWait();
blocked_--;
return true;
}
// Wait for work
platform::RecordEvent record("WaitForWork",
platform::TracerEventType::UserDefined, 10);
ec_.CommitWait(waiter);
......
......@@ -21,6 +21,10 @@ void WorkQueueOptions::Validate() const {
name.find('_'), std::string::npos,
platform::errors::InvalidArgument(
"WorkQueueOptions.name shouldn't contain an underline"));
PADDLE_ENFORCE_EQ(
allow_spinning == false && always_spinning == true, false,
platform::errors::InvalidArgument("WorkQueueOptions.allow_spinning must "
"be true when always_spinning is set"));
}
namespace {
......@@ -40,7 +44,8 @@ class WorkQueueImpl : public WorkQueue {
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
queue_ = new NonblockingThreadPool(options_.name, options_.num_threads,
options_.allow_spinning);
options_.allow_spinning,
options_.always_spinning);
}
virtual ~WorkQueueImpl() {
......@@ -127,8 +132,9 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent);
}
queues_[idx] = new (&queues_storage_[idx]) NonblockingThreadPool(
options.name, options.num_threads, options.allow_spinning);
queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.name, options.num_threads,
options.allow_spinning, options.always_spinning);
}
}
......
......@@ -64,11 +64,12 @@ struct WorkQueueOptions {
}
WorkQueueOptions(const std::string& name, size_t num_threads,
bool allow_spinning, bool track_task, bool detached,
EventsWaiter* waiter)
bool allow_spinning, bool always_spinning, bool track_task,
bool detached, EventsWaiter* waiter)
: name(name),
num_threads(num_threads),
allow_spinning(allow_spinning),
always_spinning(always_spinning),
track_task(track_task),
detached(detached),
events_waiter(waiter) {
......@@ -80,7 +81,11 @@ struct WorkQueueOptions {
std::string name;
size_t num_threads;
// Worker threads will spin for a while if this flag is set.
bool allow_spinning;
// Worker threads will never sleep if this flag is set.
// Better performance vs. higher CPU utilization.
bool always_spinning{false};
// If you need to blocking the calling thread to wait "queue empty", set
// track_task = true and set events_waiter. EventsWaiter::WaitEvent will
// block the calling thread until any of events (including "queue empty")
......
......@@ -48,6 +48,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
EventsWaiter events_waiter;
WorkQueueOptions options(/*name*/ "SingleThreadedWorkQueueForTesting",
/*num_threads*/ 1, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ true,
&events_waiter);
auto work_queue = CreateSingleThreadedWorkQueue(options);
......@@ -69,6 +70,15 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum);
EXPECT_EQ(handle.get(), 1234);
work_queue.reset();
// Test default_options with no spinning
WorkQueueOptions default_options("SingleThreadedWorkQueueForTesting",
/*num_threads*/ 1,
/*allow_spinning*/ false,
/*track_task*/ false);
work_queue = CreateSingleThreadedWorkQueue(default_options);
handle = work_queue->AddAwaitableTask([]() { return 5678; });
EXPECT_EQ(handle.get(), 5678);
}
TEST(WorkQueue, TestMultiThreadedWorkQueue) {
......@@ -85,6 +95,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
EventsWaiter events_waiter;
WorkQueueOptions options(/*name*/ "MultiThreadedWorkQueueForTesting",
/*num_threads*/ 10, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ false,
&events_waiter);
auto work_queue = CreateMultiThreadedWorkQueue(options);
......@@ -115,6 +126,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
});
work_queue.reset();
waiter_thread.join();
// Forever spin unittest
WorkQueueOptions default_options("MultiThreadedWorkQueueForTesting",
/*num_threads*/ 10, /*allow_spinning*/ false,
/*track_task*/ false);
work_queue = CreateMultiThreadedWorkQueue(default_options);
auto handle = work_queue->AddAwaitableTask([]() { return 5678; });
EXPECT_EQ(handle.get(), 5678);
}
TEST(WorkQueue, TestWorkQueueGroup) {
......@@ -130,10 +148,12 @@ TEST(WorkQueue, TestWorkQueueGroup) {
EventsWaiter events_waiter;
WorkQueueOptions sq_options(/*name*/ "SingleThreadedWorkQueueForTesting",
/*num_threads*/ 1, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ false,
&events_waiter);
WorkQueueOptions mq_options(/*name*/ "MultiThreadedWorkQueueForTesting",
/*num_threads*/ 10, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ false,
&events_waiter);
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册