未验证 提交 323d55a7 编写于 作者: L liutiexing 提交者: GitHub

AddAwaitableTask (#40770)


* AddAwaitableTask for WorkQueue
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 7e3752bb
...@@ -15,9 +15,12 @@ ...@@ -15,9 +15,12 @@
#pragma once #pragma once
#include <functional> #include <functional>
#include <future>
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits>
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -25,6 +28,29 @@ namespace framework { ...@@ -25,6 +28,29 @@ namespace framework {
constexpr const char* kQueueEmptyEvent = "QueueEmpty"; constexpr const char* kQueueEmptyEvent = "QueueEmpty";
constexpr const char* kQueueDestructEvent = "QueueDestruct"; constexpr const char* kQueueDestructEvent = "QueueDestruct";
// For std::function
// https://stackoverflow.com/questions/25421346/how-to-create-an-stdfunction-from-a-move-capturing-lambda-expression
template <typename OnlyMovable>
class FakeCopyable {
public:
explicit FakeCopyable(OnlyMovable&& obj) : obj_(std::move(obj)) {
static_assert(std::is_copy_constructible<OnlyMovable>::value == false,
"Need not to use FakeCopyable");
}
FakeCopyable(FakeCopyable&& other) : obj_(std::move(other.obj_)) {}
FakeCopyable(const FakeCopyable& other) {
PADDLE_THROW(platform::errors::Unavailable(
"Never use the copy constructor of FakeCopyable."));
}
OnlyMovable& Get() { return obj_; }
private:
OnlyMovable obj_;
};
class EventsWaiter; class EventsWaiter;
struct WorkQueueOptions { struct WorkQueueOptions {
...@@ -78,6 +104,22 @@ class WorkQueue { ...@@ -78,6 +104,22 @@ class WorkQueue {
virtual void AddTask(std::function<void()> fn) = 0; virtual void AddTask(std::function<void()> fn) = 0;
// Higher cost than AddTask
template <typename F, typename... Args>
std::future<typename std::result_of<F(Args...)>::type> AddAwaitableTask(
F&& f, Args&&... args) {
using ReturnType = typename std::result_of<F(Args...)>::type;
std::function<ReturnType()> task =
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
std::promise<ReturnType> prom;
std::future<ReturnType> res = prom.get_future();
AddTask([
t = std::move(task),
p = FakeCopyable<std::promise<ReturnType>>(std::move(prom))
]() mutable { p.Get().set_value(t()); });
return res;
}
// See WorkQueueOptions.track_task for details // See WorkQueueOptions.track_task for details
// virtual void WaitQueueEmpty() = 0; // virtual void WaitQueueEmpty() = 0;
...@@ -102,6 +144,22 @@ class WorkQueueGroup { ...@@ -102,6 +144,22 @@ class WorkQueueGroup {
virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0; virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0;
// Higher cost than AddTask
template <typename F, typename... Args>
std::future<typename std::result_of<F(Args...)>::type> AddAwaitableTask(
size_t queue_idx, F&& f, Args&&... args) {
using ReturnType = typename std::result_of<F(Args...)>::type;
std::function<ReturnType()> task =
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
std::promise<ReturnType> prom;
std::future<ReturnType> res = prom.get_future();
AddTask(queue_idx, [
t = std::move(task),
p = FakeCopyable<std::promise<ReturnType>>(std::move(prom))
]() mutable { p.Get().set_value(t()); });
return res;
}
// See WorkQueueOptions.track_task for details // See WorkQueueOptions.track_task for details
// virtual void WaitQueueGroupEmpty() = 0; // virtual void WaitQueueGroupEmpty() = 0;
......
...@@ -60,11 +60,13 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { ...@@ -60,11 +60,13 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
} }
finished = true; finished = true;
}); });
auto handle = work_queue->AddAwaitableTask([]() { return 1234; });
// WaitQueueEmpty // WaitQueueEmpty
EXPECT_EQ(finished.load(), false); EXPECT_EQ(finished.load(), false);
events_waiter.WaitEvent(); events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum); EXPECT_EQ(counter.load(), kLoopNum);
EXPECT_EQ(handle.get(), 1234);
} }
TEST(WorkQueue, TestMultiThreadedWorkQueue) { TEST(WorkQueue, TestMultiThreadedWorkQueue) {
...@@ -146,6 +148,9 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -146,6 +148,9 @@ TEST(WorkQueue, TestWorkQueueGroup) {
++counter; ++counter;
} }
}); });
int random_num = 123456;
auto handle =
queue_group->AddAwaitableTask(1, [random_num]() { return random_num; });
// WaitQueueGroupEmpty // WaitQueueGroupEmpty
events_waiter.WaitEvent(); events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
...@@ -154,4 +159,5 @@ TEST(WorkQueue, TestWorkQueueGroup) { ...@@ -154,4 +159,5 @@ TEST(WorkQueue, TestWorkQueueGroup) {
events_waiter.WaitEvent(); events_waiter.WaitEvent();
queue_group.reset(); queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent); EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
EXPECT_EQ(handle.get(), random_num);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册