未验证 提交 323d55a7 编写于 作者: L liutiexing 提交者: GitHub

AddAwaitableTask (#40770)


* AddAwaitableTask for WorkQueue
Co-authored-by: Nliutiexing <liutiexing@google.com>
上级 7e3752bb
......@@ -15,9 +15,12 @@
#pragma once
#include <functional>
#include <future>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
......@@ -25,6 +28,29 @@ namespace framework {
constexpr const char* kQueueEmptyEvent = "QueueEmpty";
constexpr const char* kQueueDestructEvent = "QueueDestruct";
// For std::function
// https://stackoverflow.com/questions/25421346/how-to-create-an-stdfunction-from-a-move-capturing-lambda-expression
template <typename OnlyMovable>
class FakeCopyable {
public:
explicit FakeCopyable(OnlyMovable&& obj) : obj_(std::move(obj)) {
static_assert(std::is_copy_constructible<OnlyMovable>::value == false,
"Need not to use FakeCopyable");
}
FakeCopyable(FakeCopyable&& other) : obj_(std::move(other.obj_)) {}
FakeCopyable(const FakeCopyable& other) {
PADDLE_THROW(platform::errors::Unavailable(
"Never use the copy constructor of FakeCopyable."));
}
OnlyMovable& Get() { return obj_; }
private:
OnlyMovable obj_;
};
class EventsWaiter;
struct WorkQueueOptions {
......@@ -78,6 +104,22 @@ class WorkQueue {
virtual void AddTask(std::function<void()> fn) = 0;
// Higher cost than AddTask
template <typename F, typename... Args>
std::future<typename std::result_of<F(Args...)>::type> AddAwaitableTask(
F&& f, Args&&... args) {
using ReturnType = typename std::result_of<F(Args...)>::type;
std::function<ReturnType()> task =
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
std::promise<ReturnType> prom;
std::future<ReturnType> res = prom.get_future();
AddTask([
t = std::move(task),
p = FakeCopyable<std::promise<ReturnType>>(std::move(prom))
]() mutable { p.Get().set_value(t()); });
return res;
}
// See WorkQueueOptions.track_task for details
// virtual void WaitQueueEmpty() = 0;
......@@ -102,6 +144,22 @@ class WorkQueueGroup {
virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0;
// Higher cost than AddTask
template <typename F, typename... Args>
std::future<typename std::result_of<F(Args...)>::type> AddAwaitableTask(
size_t queue_idx, F&& f, Args&&... args) {
using ReturnType = typename std::result_of<F(Args...)>::type;
std::function<ReturnType()> task =
std::bind(std::forward<F>(f), std::forward<Args>(args)...);
std::promise<ReturnType> prom;
std::future<ReturnType> res = prom.get_future();
AddTask(queue_idx, [
t = std::move(task),
p = FakeCopyable<std::promise<ReturnType>>(std::move(prom))
]() mutable { p.Get().set_value(t()); });
return res;
}
// See WorkQueueOptions.track_task for details
// virtual void WaitQueueGroupEmpty() = 0;
......
......@@ -60,11 +60,13 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
}
finished = true;
});
auto handle = work_queue->AddAwaitableTask([]() { return 1234; });
// WaitQueueEmpty
EXPECT_EQ(finished.load(), false);
events_waiter.WaitEvent();
EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum);
EXPECT_EQ(handle.get(), 1234);
}
TEST(WorkQueue, TestMultiThreadedWorkQueue) {
......@@ -146,6 +148,9 @@ TEST(WorkQueue, TestWorkQueueGroup) {
++counter;
}
});
int random_num = 123456;
auto handle =
queue_group->AddAwaitableTask(1, [random_num]() { return random_num; });
// WaitQueueGroupEmpty
events_waiter.WaitEvent();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
......@@ -154,4 +159,5 @@ TEST(WorkQueue, TestWorkQueueGroup) {
events_waiter.WaitEvent();
queue_group.reset();
EXPECT_EQ(events_waiter.WaitEvent(), paddle::framework::kQueueDestructEvent);
EXPECT_EQ(handle.get(), random_num);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册