未验证 提交 9db6c762 编写于 作者: L liutiexing 提交者: GitHub

WorkQueue supports always_spinning option (#42029)

* WorkQueue supports always_spinning option

* update

* update
上级 920d44df
...@@ -63,6 +63,7 @@ class AsyncWorkQueue { ...@@ -63,6 +63,7 @@ class AsyncWorkQueue {
group_options.emplace_back(/*name*/ "HostTasks", group_options.emplace_back(/*name*/ "HostTasks",
/*num_threads*/ host_num_threads, /*num_threads*/ host_num_threads,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false, /*track_task*/ false,
/*detached*/ true, /*detached*/ true,
/*events_waiter*/ waiter); /*events_waiter*/ waiter);
...@@ -70,6 +71,7 @@ class AsyncWorkQueue { ...@@ -70,6 +71,7 @@ class AsyncWorkQueue {
group_options.emplace_back(/*name*/ "DeviceKernelLaunch", group_options.emplace_back(/*name*/ "DeviceKernelLaunch",
/*num_threads*/ deivce_num_threads, /*num_threads*/ deivce_num_threads,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ false, /*track_task*/ false,
/*detached*/ true, /*detached*/ true,
/*events_waiter*/ waiter); /*events_waiter*/ waiter);
...@@ -77,6 +79,7 @@ class AsyncWorkQueue { ...@@ -77,6 +79,7 @@ class AsyncWorkQueue {
group_options.emplace_back(/*name*/ "Prepare", group_options.emplace_back(/*name*/ "Prepare",
/*num_threads*/ 1, /*num_threads*/ 1,
/*allow_spinning*/ true, /*allow_spinning*/ true,
/*always_spinning*/ false,
/*track_task*/ false, /*track_task*/ false,
/*detached*/ true, /*detached*/ true,
/*events_waiter*/ waiter); /*events_waiter*/ waiter);
......
...@@ -29,13 +29,13 @@ class ThreadPoolTempl { ...@@ -29,13 +29,13 @@ class ThreadPoolTempl {
typedef RunQueue<Task, 1024> Queue; typedef RunQueue<Task, 1024> Queue;
ThreadPoolTempl(const std::string& name, int num_threads, bool allow_spinning, ThreadPoolTempl(const std::string& name, int num_threads, bool allow_spinning,
Environment env = Environment()) bool always_spinning, Environment env = Environment())
: env_(env), : env_(env),
allow_spinning_(allow_spinning), allow_spinning_(allow_spinning),
always_spinning_(always_spinning),
global_steal_partition_(EncodePartition(0, num_threads_)), global_steal_partition_(EncodePartition(0, num_threads_)),
blocked_(0), blocked_(0),
num_tasks_(0), num_tasks_(0),
spinning_(0),
done_(false), done_(false),
cancelled_(false), cancelled_(false),
ec_(num_threads), ec_(num_threads),
...@@ -236,11 +236,11 @@ class ThreadPoolTempl { ...@@ -236,11 +236,11 @@ class ThreadPoolTempl {
Environment env_; Environment env_;
const bool allow_spinning_; const bool allow_spinning_;
const bool always_spinning_;
std::vector<std::vector<unsigned>> all_coprimes_; std::vector<std::vector<unsigned>> all_coprimes_;
unsigned global_steal_partition_; unsigned global_steal_partition_;
std::atomic<unsigned> blocked_; std::atomic<unsigned> blocked_;
std::atomic<uint64_t> num_tasks_; std::atomic<uint64_t> num_tasks_;
std::atomic<bool> spinning_;
std::atomic<bool> done_; std::atomic<bool> done_;
std::atomic<bool> cancelled_; std::atomic<bool> cancelled_;
EventCount ec_; EventCount ec_;
...@@ -417,6 +417,15 @@ class ThreadPoolTempl { ...@@ -417,6 +417,15 @@ class ThreadPoolTempl {
ec_.Notify(true); ec_.Notify(true);
return false; return false;
} }
// Cancel wait if always_spinning_
if (always_spinning_) {
ec_.CancelWait();
blocked_--;
return true;
}
// Wait for work
platform::RecordEvent record("WaitForWork", platform::RecordEvent record("WaitForWork",
platform::TracerEventType::UserDefined, 10); platform::TracerEventType::UserDefined, 10);
ec_.CommitWait(waiter); ec_.CommitWait(waiter);
......
...@@ -21,6 +21,10 @@ void WorkQueueOptions::Validate() const { ...@@ -21,6 +21,10 @@ void WorkQueueOptions::Validate() const {
name.find('_'), std::string::npos, name.find('_'), std::string::npos,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"WorkQueueOptions.name shouldn't contain an underline")); "WorkQueueOptions.name shouldn't contain an underline"));
PADDLE_ENFORCE_EQ(
allow_spinning == false && always_spinning == true, false,
platform::errors::InvalidArgument("WorkQueueOptions.allow_spinning must "
"be true when always_spinning is set"));
} }
namespace { namespace {
...@@ -40,7 +44,8 @@ class WorkQueueImpl : public WorkQueue { ...@@ -40,7 +44,8 @@ class WorkQueueImpl : public WorkQueue {
options.events_waiter->RegisterEvent(kQueueDestructEvent); options.events_waiter->RegisterEvent(kQueueDestructEvent);
} }
queue_ = new NonblockingThreadPool(options_.name, options_.num_threads, queue_ = new NonblockingThreadPool(options_.name, options_.num_threads,
options_.allow_spinning); options_.allow_spinning,
options_.always_spinning);
} }
virtual ~WorkQueueImpl() { virtual ~WorkQueueImpl() {
...@@ -127,8 +132,9 @@ WorkQueueGroupImpl::WorkQueueGroupImpl( ...@@ -127,8 +132,9 @@ WorkQueueGroupImpl::WorkQueueGroupImpl(
destruct_notifier_ = destruct_notifier_ =
options.events_waiter->RegisterEvent(kQueueDestructEvent); options.events_waiter->RegisterEvent(kQueueDestructEvent);
} }
queues_[idx] = new (&queues_storage_[idx]) NonblockingThreadPool( queues_[idx] = new (&queues_storage_[idx])
options.name, options.num_threads, options.allow_spinning); NonblockingThreadPool(options.name, options.num_threads,
options.allow_spinning, options.always_spinning);
} }
} }
......
...@@ -64,11 +64,12 @@ struct WorkQueueOptions { ...@@ -64,11 +64,12 @@ struct WorkQueueOptions {
} }
WorkQueueOptions(const std::string& name, size_t num_threads, WorkQueueOptions(const std::string& name, size_t num_threads,
bool allow_spinning, bool track_task, bool detached, bool allow_spinning, bool always_spinning, bool track_task,
EventsWaiter* waiter) bool detached, EventsWaiter* waiter)
: name(name), : name(name),
num_threads(num_threads), num_threads(num_threads),
allow_spinning(allow_spinning), allow_spinning(allow_spinning),
always_spinning(always_spinning),
track_task(track_task), track_task(track_task),
detached(detached), detached(detached),
events_waiter(waiter) { events_waiter(waiter) {
...@@ -80,7 +81,11 @@ struct WorkQueueOptions { ...@@ -80,7 +81,11 @@ struct WorkQueueOptions {
std::string name; std::string name;
size_t num_threads; size_t num_threads;
// Worker threads will spin for a while if this flag is set.
bool allow_spinning; bool allow_spinning;
// Worker threads will never sleep if this flag is set.
// Better performance vs. higher CPU utilization.
bool always_spinning{false};
// If you need to blocking the calling thread to wait "queue empty", set // If you need to blocking the calling thread to wait "queue empty", set
// track_task = true and set events_waiter. EventsWaiter::WaitEvent will // track_task = true and set events_waiter. EventsWaiter::WaitEvent will
// block the calling thread until any of events (including "queue empty") // block the calling thread until any of events (including "queue empty")
......
...@@ -48,6 +48,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { ...@@ -48,6 +48,7 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
EventsWaiter events_waiter; EventsWaiter events_waiter;
WorkQueueOptions options(/*name*/ "SingleThreadedWorkQueueForTesting", WorkQueueOptions options(/*name*/ "SingleThreadedWorkQueueForTesting",
/*num_threads*/ 1, /*allow_spinning*/ true, /*num_threads*/ 1, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ true, /*track_task*/ true, /*detached*/ true,
&events_waiter); &events_waiter);
auto work_queue = CreateSingleThreadedWorkQueue(options); auto work_queue = CreateSingleThreadedWorkQueue(options);
...@@ -69,6 +70,15 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { ...@@ -69,6 +70,15 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum); EXPECT_EQ(counter.load(), kLoopNum);
EXPECT_EQ(handle.get(), 1234); EXPECT_EQ(handle.get(), 1234);
work_queue.reset();
// Test default_options with no spinning
WorkQueueOptions default_options("SingleThreadedWorkQueueForTesting",
/*num_threads*/ 1,
/*allow_spinning*/ false,
/*track_task*/ false);
work_queue = CreateSingleThreadedWorkQueue(default_options);
handle = work_queue->AddAwaitableTask([]() { return 5678; });
EXPECT_EQ(handle.get(), 5678);
} }
TEST(WorkQueue, TestMultiThreadedWorkQueue) { TEST(WorkQueue, TestMultiThreadedWorkQueue) {
...@@ -85,6 +95,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -85,6 +95,7 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
EventsWaiter events_waiter; EventsWaiter events_waiter;
WorkQueueOptions options(/*name*/ "MultiThreadedWorkQueueForTesting", WorkQueueOptions options(/*name*/ "MultiThreadedWorkQueueForTesting",
/*num_threads*/ 10, /*allow_spinning*/ true, /*num_threads*/ 10, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ false, /*track_task*/ true, /*detached*/ false,
&events_waiter); &events_waiter);
auto work_queue = CreateMultiThreadedWorkQueue(options); auto work_queue = CreateMultiThreadedWorkQueue(options);
...@@ -115,6 +126,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -115,6 +126,13 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
}); });
work_queue.reset(); work_queue.reset();
waiter_thread.join(); waiter_thread.join();
// Forever spin unittest
WorkQueueOptions default_options("MultiThreadedWorkQueueForTesting",
/*num_threads*/ 10, /*allow_spinning*/ false,
/*track_task*/ false);
work_queue = CreateMultiThreadedWorkQueue(default_options);
auto handle = work_queue->AddAwaitableTask([]() { return 5678; });
EXPECT_EQ(handle.get(), 5678);
} }
TEST(WorkQueue, TestWorkQueueGroup) { TEST(WorkQueue, TestWorkQueueGroup) {
...@@ -130,10 +148,12 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -130,10 +148,12 @@ TEST(WorkQueue, TestWorkQueueGroup) {
EventsWaiter events_waiter; EventsWaiter events_waiter;
WorkQueueOptions sq_options(/*name*/ "SingleThreadedWorkQueueForTesting", WorkQueueOptions sq_options(/*name*/ "SingleThreadedWorkQueueForTesting",
/*num_threads*/ 1, /*allow_spinning*/ true, /*num_threads*/ 1, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ false, /*track_task*/ true, /*detached*/ false,
&events_waiter); &events_waiter);
WorkQueueOptions mq_options(/*name*/ "MultiThreadedWorkQueueForTesting", WorkQueueOptions mq_options(/*name*/ "MultiThreadedWorkQueueForTesting",
/*num_threads*/ 10, /*allow_spinning*/ true, /*num_threads*/ 10, /*allow_spinning*/ true,
/*always_spinning*/ true,
/*track_task*/ true, /*detached*/ false, /*track_task*/ true, /*detached*/ false,
&events_waiter); &events_waiter);
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options}); auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册