From 323d55a7badd1ab7ec6a91cd6739a6c0924f87b7 Mon Sep 17 00:00:00 2001 From: liutiexing <74819124+liutiexing@users.noreply.github.com> Date: Wed, 23 Mar 2022 15:03:31 +0800 Subject: [PATCH] AddAwaitableTask (#40770) * AddAwaitableTask for WorkQueue Co-authored-by: liutiexing --- .../new_executor/workqueue/workqueue.h | 58 +++++++++++++++++++ .../new_executor/workqueue/workqueue_test.cc | 6 ++ 2 files changed, 64 insertions(+) diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue.h b/paddle/fluid/framework/new_executor/workqueue/workqueue.h index 6c8abee2f01..0101461658d 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue.h +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue.h @@ -15,9 +15,12 @@ #pragma once #include +#include #include #include +#include #include +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace framework { @@ -25,6 +28,29 @@ namespace framework { constexpr const char* kQueueEmptyEvent = "QueueEmpty"; constexpr const char* kQueueDestructEvent = "QueueDestruct"; +// For std::function +// https://stackoverflow.com/questions/25421346/how-to-create-an-stdfunction-from-a-move-capturing-lambda-expression +template +class FakeCopyable { + public: + explicit FakeCopyable(OnlyMovable&& obj) : obj_(std::move(obj)) { + static_assert(std::is_copy_constructible::value == false, + "Need not to use FakeCopyable"); + } + + FakeCopyable(FakeCopyable&& other) : obj_(std::move(other.obj_)) {} + + FakeCopyable(const FakeCopyable& other) { + PADDLE_THROW(platform::errors::Unavailable( + "Never use the copy constructor of FakeCopyable.")); + } + + OnlyMovable& Get() { return obj_; } + + private: + OnlyMovable obj_; +}; + class EventsWaiter; struct WorkQueueOptions { @@ -78,6 +104,22 @@ class WorkQueue { virtual void AddTask(std::function fn) = 0; + // Higher cost than AddTask + template + std::future::type> AddAwaitableTask( + F&& f, Args&&... args) { + using ReturnType = typename std::result_of::type; + std::function task = + std::bind(std::forward(f), std::forward(args)...); + std::promise prom; + std::future res = prom.get_future(); + AddTask([ + t = std::move(task), + p = FakeCopyable>(std::move(prom)) + ]() mutable { p.Get().set_value(t()); }); + return res; + } + // See WorkQueueOptions.track_task for details // virtual void WaitQueueEmpty() = 0; @@ -102,6 +144,22 @@ class WorkQueueGroup { virtual void AddTask(size_t queue_idx, std::function fn) = 0; + // Higher cost than AddTask + template + std::future::type> AddAwaitableTask( + size_t queue_idx, F&& f, Args&&... args) { + using ReturnType = typename std::result_of::type; + std::function task = + std::bind(std::forward(f), std::forward(args)...); + std::promise prom; + std::future res = prom.get_future(); + AddTask(queue_idx, [ + t = std::move(task), + p = FakeCopyable>(std::move(prom)) + ]() mutable { p.Get().set_value(t()); }); + return res; + } + // See WorkQueueOptions.track_task for details // virtual void WaitQueueGroupEmpty() = 0; diff --git a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc index 25448da8f10..97f0282a158 100644 --- a/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc +++ b/paddle/fluid/framework/new_executor/workqueue/workqueue_test.cc @@ -60,11 +60,13 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { } finished = true; }); + auto handle = work_queue->AddAwaitableTask([]() { return 1234; }); // WaitQueueEmpty EXPECT_EQ(finished.load(), false); events_waiter.WaitEvent(); EXPECT_EQ(finished.load(), true); EXPECT_EQ(counter.load(), kLoopNum); + EXPECT_EQ(handle.get(), 1234); } TEST(WorkQueue, TestMultiThreadedWorkQueue) { @@ -146,6 +148,9 @@ TEST(WorkQueue, TestWorkQueueGroup) { ++counter; } }); + int random_num = 123456; + auto handle = + queue_group->AddAwaitableTask(1, [random_num]() { return random_num; }); // WaitQueueGroupEmpty events_waiter.WaitEvent(); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); @@ -154,4 +159,5 @@ TEST(WorkQueue, TestWorkQueueGroup) { events_waiter.WaitEvent(); queue_group.reset(); EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); + EXPECT_EQ(handle.get(), random_num); } -- GitLab