未验证 提交 a53460aa 编写于 作者: L liutiexing 提交者: GitHub

Work queue group (#35470)

* Split Tracker and WorkQueue

* add WorkQueueGroup

* add unittest

* fix

* update

* update

* fix compile
上级 b3787d1b
cc_library(workqueue SRCS workqueue.cc) cc_library(workqueue SRCS workqueue.cc DEPS enforce)
cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry cc_library(interpretercore SRCS interpretercore.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
......
...@@ -157,7 +157,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -157,7 +157,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
garbages_.reset(new GarbageQueue()); garbages_.reset(new GarbageQueue());
max_memory_size_ = static_cast<size_t>(GetEagerDeletionThreshold()); max_memory_size_ = static_cast<size_t>(GetEagerDeletionThreshold());
cur_memory_size_ = 0; cur_memory_size_ = 0;
gc_queue_ = CreateSingleThreadedWorkQueue(); WorkQueueOptions options;
options.num_threads = 1;
gc_queue_ = CreateSingleThreadedWorkQueue(options);
feed_names_ = feed_names; feed_names_ = feed_names;
......
...@@ -19,45 +19,46 @@ ...@@ -19,45 +19,46 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CounterTracker { class TaskTracker {
public: public:
explicit CounterTracker(std::atomic<uint64_t>* counter, EventCount* ec) TaskTracker() : wait_empty_cv_(1) {}
: counter_(counter), ec_(ec) {
counter_->fetch_add(1, std::memory_order_relaxed);
}
~CounterTracker() { TaskTracker(const TaskTracker&) = delete;
if (counter_ != nullptr) {
if (1 == counter_->fetch_sub(1, std::memory_order_relaxed)) {
ec_->Notify(true);
}
}
}
CounterTracker(CounterTracker&& other) TaskTracker& operator=(const TaskTracker&) = delete;
: counter_(other.counter_), ec_(other.ec_) {
other.counter_ = nullptr;
other.ec_ = nullptr;
}
CounterTracker& operator=(CounterTracker&& other) { ~TaskTracker() = default;
counter_ = other.counter_;
ec_ = other.ec_;
other.counter_ = nullptr;
other.ec_ = nullptr;
return *this;
}
CounterTracker(const CounterTracker& other) void AddCounter() { num_tasks_.fetch_add(1, std::memory_order_relaxed); }
: counter_(other.counter_), ec_(other.ec_) {
counter_->fetch_add(1, std::memory_order_relaxed); void SubCounter() {
if (1 == num_tasks_.fetch_sub(1, std::memory_order_relaxed)) {
wait_empty_cv_.Notify(true);
}
} }
CounterTracker& operator=(const CounterTracker&) = delete; // only one user can wait at any time
void WaitTaskNumToZero() {
bool waiting = false;
if (!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
abort();
}
EventCount::Waiter* w = wait_empty_cv_.GetWaiter(0);
wait_empty_cv_.Prewait();
if (num_tasks_.load(std::memory_order_relaxed) == 0) {
wait_empty_cv_.CancelWait();
} else {
wait_empty_cv_.CommitWait(w);
}
wait_empty_.store(false);
}
private: private:
std::atomic<uint64_t>* counter_{nullptr}; std::atomic<uint64_t> num_tasks_{0};
EventCount* ec_{nullptr}; EventCount wait_empty_cv_;
std::atomic<bool> wait_empty_{false};
}; };
template <typename Environment> template <typename Environment>
...@@ -66,9 +67,6 @@ class ThreadPoolTempl { ...@@ -66,9 +67,6 @@ class ThreadPoolTempl {
typedef typename Environment::Task Task; typedef typename Environment::Task Task;
typedef RunQueue<Task, 1024> Queue; typedef RunQueue<Task, 1024> Queue;
explicit ThreadPoolTempl(int num_threads, Environment env = Environment())
: ThreadPoolTempl(num_threads, true, env) {}
ThreadPoolTempl(int num_threads, bool allow_spinning, ThreadPoolTempl(int num_threads, bool allow_spinning,
Environment env = Environment()) Environment env = Environment())
: env_(env), : env_(env),
...@@ -80,10 +78,7 @@ class ThreadPoolTempl { ...@@ -80,10 +78,7 @@ class ThreadPoolTempl {
spinning_(0), spinning_(0),
done_(false), done_(false),
cancelled_(false), cancelled_(false),
ec_(num_threads_), ec_(num_threads_) {
wait_empty_(false),
wait_empty_ec_(1),
num_tasks_(0) {
// Calculate coprimes of all numbers [1, num_threads]. // Calculate coprimes of all numbers [1, num_threads].
// Coprimes are used for random walks over all threads in Steal // Coprimes are used for random walks over all threads in Steal
// and NonEmptyQueueIndex. Iteration is based on the fact that if we take // and NonEmptyQueueIndex. Iteration is based on the fact that if we take
...@@ -146,15 +141,13 @@ class ThreadPoolTempl { ...@@ -146,15 +141,13 @@ class ThreadPoolTempl {
} }
void AddTaskWithHint(std::function<void()> fn, int start, int limit) { void AddTaskWithHint(std::function<void()> fn, int start, int limit) {
Task t = env_.CreateTask([ Task t = env_.CreateTask(std::move(fn));
task = std::move(fn), raii = CounterTracker(&num_tasks_, &wait_empty_ec_)
]() mutable { task(); });
PerThread* pt = GetPerThread(); PerThread* pt = GetPerThread();
if (pt->pool == this) { if (pt->pool == this) {
// Worker thread of this pool, push onto the thread's queue. // Worker thread of this pool, push onto the thread's queue.
Queue& q = thread_data_[pt->thread_id].queue; Queue& q = thread_data_[pt->thread_id].queue;
t = q.PushFront(std::move(t)); t = q.PushFront(std::move(t));
} else if (wait_empty_.load() == false) { } else {
// A free-standing thread (or worker of another pool), push onto a random // A free-standing thread (or worker of another pool), push onto a random
// queue. // queue.
assert(start < limit); assert(start < limit);
...@@ -179,29 +172,6 @@ class ThreadPoolTempl { ...@@ -179,29 +172,6 @@ class ThreadPoolTempl {
} }
} }
void WaitQueueEmpty() {
bool waiting = wait_empty_.load();
assert(waiting == false);
if (waiting ||
!wait_empty_.compare_exchange_strong(waiting, true,
std::memory_order_acquire)) {
abort();
}
EventCount::Waiter* w = wait_empty_ec_.GetWaiter(0);
wait_empty_ec_.Prewait();
if (num_tasks_.load() == 0) {
wait_empty_ec_.CancelWait();
} else {
wait_empty_ec_.CommitWait(w);
}
waiting = true;
if (!waiting ||
!wait_empty_.compare_exchange_strong(waiting, false,
std::memory_order_acquire)) {
abort();
}
}
void Cancel() { void Cancel() {
cancelled_ = true; cancelled_ = true;
done_ = true; done_ = true;
...@@ -300,10 +270,6 @@ class ThreadPoolTempl { ...@@ -300,10 +270,6 @@ class ThreadPoolTempl {
std::atomic<bool> cancelled_; std::atomic<bool> cancelled_;
EventCount ec_; EventCount ec_;
std::atomic<bool> wait_empty_;
EventCount wait_empty_ec_;
std::atomic<uint64_t> num_tasks_;
// Main worker thread loop. // Main worker thread loop.
void WorkerLoop(int thread_id) { void WorkerLoop(int thread_id) {
PerThread* pt = GetPerThread(); PerThread* pt = GetPerThread();
......
...@@ -194,7 +194,7 @@ class RunQueue { ...@@ -194,7 +194,7 @@ class RunQueue {
private: private:
static const unsigned kMask = kSize - 1; static const unsigned kMask = kSize - 1;
static const unsigned kMask2 = (kSize << 1) - 1; static const unsigned kMask2 = (kSize << 1) - 1;
struct Elem { struct alignas(64) Elem {
std::atomic<uint8_t> state; std::atomic<uint8_t> state;
Work w; Work w;
}; };
...@@ -212,8 +212,8 @@ class RunQueue { ...@@ -212,8 +212,8 @@ class RunQueue {
// position, these conditions would be indistinguishable); (2) obtain // position, these conditions would be indistinguishable); (2) obtain
// consistent snapshot of front_/back_ for Size operation using the // consistent snapshot of front_/back_ for Size operation using the
// modification counters. // modification counters.
std::atomic<unsigned> front_; alignas(64) std::atomic<unsigned> front_;
std::atomic<unsigned> back_; alignas(64) std::atomic<unsigned> back_;
Elem array_[kSize]; Elem array_[kSize];
// SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false, // SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
......
...@@ -6,63 +6,168 @@ ...@@ -6,63 +6,168 @@
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h" #include "paddle/fluid/framework/new_executor/nonblocking_threadpool.h"
#include "paddle/fluid/framework/new_executor/workqueue_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace {
class SingleThreadedWorkQueue : public WorkQueue { class WorkQueueImpl : public WorkQueue {
public: public:
SingleThreadedWorkQueue() : queue_(1) {} explicit WorkQueueImpl(const WorkQueueOptions& options)
: WorkQueue(options), queue_(nullptr), tracker_(nullptr) {
SingleThreadedWorkQueue(const SingleThreadedWorkQueue&) = delete; if (options_.track_task) {
tracker_ = new TaskTracker;
SingleThreadedWorkQueue& operator=(const SingleThreadedWorkQueue&) = delete; }
queue_ = new NonblockingThreadPool(options_.num_threads,
options_.allow_spinning);
}
virtual ~SingleThreadedWorkQueue() = default; virtual ~WorkQueueImpl() {
delete tracker_;
delete queue_;
}
void AddTask(std::function<void()> fn) override { void AddTask(std::function<void()> fn) override {
queue_.AddTask(std::move(fn)); if (tracker_ != nullptr) {
fn = [
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
]() mutable {
task();
};
}
queue_->AddTask(std::move(fn));
} }
void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } void WaitQueueEmpty() override {
if (tracker_ == nullptr) {
PADDLE_THROW(
platform::errors::Unavailable("set WorkQueueOptions.track_task = "
"true before call this interface."));
}
tracker_->WaitTaskNumToZero();
}
size_t NumThreads() override { return queue_.NumThreads(); } size_t NumThreads() const override { return queue_->NumThreads(); }
private: private:
NonblockingThreadPool queue_; NonblockingThreadPool* queue_;
TaskTracker* tracker_;
}; };
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue() { class WorkQueueGroupImpl : public WorkQueueGroup {
std::unique_ptr<WorkQueue> ptr(new SingleThreadedWorkQueue); public:
return std::move(ptr); explicit WorkQueueGroupImpl(
const std::vector<WorkQueueOptions>& queue_options);
~WorkQueueGroupImpl();
void AddTask(size_t queue_idx, std::function<void()> fn) override;
void WaitQueueGroupEmpty() override;
size_t QueueNumThreads(size_t queue_idx) const override;
size_t QueueGroupNumThreads() const override;
private:
std::vector<NonblockingThreadPool*> queues_;
NonblockingThreadPool* queues_storage_;
TaskTracker* tracker_;
};
WorkQueueGroupImpl::WorkQueueGroupImpl(
const std::vector<WorkQueueOptions>& queues_options)
: WorkQueueGroup(queues_options),
queues_storage_(nullptr),
tracker_(nullptr) {
size_t num_queues = queues_options_.size();
queues_.resize(num_queues);
void* buffer = malloc(sizeof(NonblockingThreadPool) * num_queues);
queues_storage_ = reinterpret_cast<NonblockingThreadPool*>(buffer);
for (size_t idx = 0; idx < num_queues; ++idx) {
const auto& options = queues_options_[idx];
if (options.track_task && tracker_ == nullptr) {
tracker_ = new TaskTracker;
}
queues_[idx] = new (&queues_storage_[idx])
NonblockingThreadPool(options.num_threads, options.allow_spinning);
}
} }
class MultiThreadedWorkQueue : public WorkQueue { WorkQueueGroupImpl::~WorkQueueGroupImpl() {
public: for (auto queue : queues_) {
explicit MultiThreadedWorkQueue(int num_threads) : queue_(num_threads) { queue->~NonblockingThreadPool();
assert(num_threads > 1);
} }
delete tracker_;
free(queues_storage_);
}
MultiThreadedWorkQueue(const MultiThreadedWorkQueue&) = delete; void WorkQueueGroupImpl::AddTask(size_t queue_idx, std::function<void()> fn) {
assert(queue_idx < queues_.size());
if (queues_options_.at(queue_idx).track_task) {
fn = [
task = std::move(fn), raii = CounterGuard<TaskTracker>(tracker_)
]() mutable {
task();
};
}
queues_[queue_idx]->AddTask(std::move(fn));
}
MultiThreadedWorkQueue& operator=(const MultiThreadedWorkQueue&) = delete; void WorkQueueGroupImpl::WaitQueueGroupEmpty() {
if (nullptr == tracker_) {
PADDLE_THROW(platform::errors::Unavailable(
"set WorkQueueOptions.track_task = true for at least one of queues "
"before call this interface."));
}
tracker_->WaitTaskNumToZero();
}
virtual ~MultiThreadedWorkQueue() = default; size_t WorkQueueGroupImpl::QueueNumThreads(size_t queue_idx) const {
assert(queue_idx < queues_.size());
return queues_.at(queue_idx)->NumThreads();
}
void AddTask(std::function<void()> fn) override { size_t WorkQueueGroupImpl::QueueGroupNumThreads() const {
queue_.AddTask(std::move(fn)); size_t total_num = 0;
for (auto queue : queues_) {
total_num += queue->NumThreads();
} }
return total_num;
}
void WaitQueueEmpty() override { queue_.WaitQueueEmpty(); } } // namespace
size_t NumThreads() override { return queue_.NumThreads(); } std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(
const WorkQueueOptions& options) {
PADDLE_ENFORCE_EQ(options.num_threads, 1u,
platform::errors::InvalidArgument(
"For a SingleThreadedWorkQueue, "
"WorkQueueOptions.num_threads must equals to 1."));
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
return std::move(ptr);
}
private: std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(
NonblockingThreadPool queue_; const WorkQueueOptions& options) {
}; PADDLE_ENFORCE_GT(
options.num_threads, 1u,
platform::errors::InvalidArgument("For a MultiThreadedWorkQueue, "
"WorkQueueOptions.num_threads must be "
"greater than 1."));
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
return std::move(ptr);
}
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(int num_threads) { std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
std::unique_ptr<WorkQueue> ptr(new MultiThreadedWorkQueue(num_threads)); const std::vector<WorkQueueOptions>& queues_options) {
PADDLE_ENFORCE_GT(queues_options.size(), 1u,
platform::errors::InvalidArgument(
"For a WorkQueueGroup, the number of WorkQueueOptions "
"must be greater than 1."));
std::unique_ptr<WorkQueueGroup> ptr(new WorkQueueGroupImpl(queues_options));
return std::move(ptr); return std::move(ptr);
} }
......
...@@ -16,13 +16,20 @@ ...@@ -16,13 +16,20 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct WorkQueueOptions {
size_t num_threads{0};
bool allow_spinning{true};
bool track_task{false};
};
class WorkQueue { class WorkQueue {
public: public:
WorkQueue() = default; explicit WorkQueue(const WorkQueueOptions& options) : options_(options) {}
WorkQueue(const WorkQueue&) = delete; WorkQueue(const WorkQueue&) = delete;
...@@ -32,14 +39,49 @@ class WorkQueue { ...@@ -32,14 +39,49 @@ class WorkQueue {
virtual void AddTask(std::function<void()> fn) = 0; virtual void AddTask(std::function<void()> fn) = 0;
// set WorkQueueOptions.track_task = true before call this
// interface, otherwise will abort()
virtual void WaitQueueEmpty() = 0; virtual void WaitQueueEmpty() = 0;
virtual size_t NumThreads() = 0; virtual size_t NumThreads() const = 0;
protected:
WorkQueueOptions options_;
}; };
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(); class WorkQueueGroup {
public:
explicit WorkQueueGroup(const std::vector<WorkQueueOptions>& queues_options)
: queues_options_(queues_options) {}
WorkQueueGroup(const WorkQueueGroup&) = delete;
WorkQueueGroup& operator=(const WorkQueueGroup&) = delete;
virtual ~WorkQueueGroup() = default;
virtual void AddTask(size_t queue_idx, std::function<void()> fn) = 0;
// set WorkQueueOptions.track_task = true for at least one of queues
// before call this interface, otherwise will abort()
virtual void WaitQueueGroupEmpty() = 0;
virtual size_t QueueNumThreads(size_t queue_idx) const = 0;
virtual size_t QueueGroupNumThreads() const = 0;
protected:
std::vector<WorkQueueOptions> queues_options_;
};
std::unique_ptr<WorkQueue> CreateSingleThreadedWorkQueue(
const WorkQueueOptions& options);
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(
const WorkQueueOptions& options);
std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(int num_threads); std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
const std::vector<WorkQueueOptions>& queues_options);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,13 +19,17 @@ ...@@ -19,13 +19,17 @@
TEST(WorkQueue, TestSingleThreadedWorkQueue) { TEST(WorkQueue, TestSingleThreadedWorkQueue) {
VLOG(1) << "In Test"; VLOG(1) << "In Test";
using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueue; using paddle::framework::WorkQueue;
using paddle::framework::CreateSingleThreadedWorkQueue; using paddle::framework::CreateSingleThreadedWorkQueue;
std::atomic<bool> finished{false}; std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0}; std::atomic<unsigned> counter{0};
constexpr unsigned kLoopNum = 1000000; constexpr unsigned kLoopNum = 1000000;
// CreateSingleThreadedWorkQueue // CreateSingleThreadedWorkQueue
std::unique_ptr<WorkQueue> work_queue = CreateSingleThreadedWorkQueue(); WorkQueueOptions options;
options.num_threads = 1;
options.track_task = true;
auto work_queue = CreateSingleThreadedWorkQueue(options);
// NumThreads // NumThreads
EXPECT_EQ(work_queue->NumThreads(), 1u); EXPECT_EQ(work_queue->NumThreads(), 1u);
// AddTask // AddTask
...@@ -46,14 +50,18 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) { ...@@ -46,14 +50,18 @@ TEST(WorkQueue, TestSingleThreadedWorkQueue) {
TEST(WorkQueue, TestMultiThreadedWorkQueue) { TEST(WorkQueue, TestMultiThreadedWorkQueue) {
VLOG(1) << "In Test"; VLOG(1) << "In Test";
using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueue; using paddle::framework::WorkQueue;
using paddle::framework::CreateMultiThreadedWorkQueue; using paddle::framework::CreateMultiThreadedWorkQueue;
std::atomic<bool> finished{false}; std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0}; std::atomic<unsigned> counter{0};
constexpr unsigned kExternalLoopNum = 100; constexpr unsigned kExternalLoopNum = 100;
constexpr unsigned kLoopNum = 1000000; constexpr unsigned kLoopNum = 1000000;
// CreateSingleThreadedWorkQueue // CreateMultiThreadedWorkQueue
std::unique_ptr<WorkQueue> work_queue = CreateMultiThreadedWorkQueue(10); WorkQueueOptions options;
options.num_threads = 10;
options.track_task = true;
auto work_queue = CreateMultiThreadedWorkQueue(options);
// NumThreads // NumThreads
EXPECT_EQ(work_queue->NumThreads(), 10u); EXPECT_EQ(work_queue->NumThreads(), 10u);
// AddTask // AddTask
...@@ -73,3 +81,42 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) { ...@@ -73,3 +81,42 @@ TEST(WorkQueue, TestMultiThreadedWorkQueue) {
EXPECT_EQ(finished.load(), true); EXPECT_EQ(finished.load(), true);
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum); EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum);
} }
TEST(WorkQueue, TestWorkQueueGroup) {
using paddle::framework::WorkQueueOptions;
using paddle::framework::WorkQueueGroup;
using paddle::framework::CreateWorkQueueGroup;
std::atomic<bool> finished{false};
std::atomic<unsigned> counter{0};
constexpr unsigned kExternalLoopNum = 100;
constexpr unsigned kLoopNum = 1000000;
// CreateMultiThreadedWorkQueue
WorkQueueOptions sq_options;
sq_options.num_threads = 1;
sq_options.track_task = true;
WorkQueueOptions mq_options;
mq_options.num_threads = 10;
mq_options.track_task = true;
auto queue_group = CreateWorkQueueGroup({sq_options, mq_options});
// NumThreads
EXPECT_EQ(queue_group->QueueNumThreads(0), 1u);
EXPECT_EQ(queue_group->QueueNumThreads(1), 10u);
EXPECT_EQ(queue_group->QueueGroupNumThreads(), 11u);
// AddTask
EXPECT_EQ(counter.load(), 0u);
for (unsigned i = 0; i < kExternalLoopNum; ++i) {
queue_group->AddTask(1, [&counter, &finished, kLoopNum]() {
for (unsigned i = 0; i < kLoopNum; ++i) {
++counter;
}
});
}
queue_group->AddTask(0, [&counter, &finished, kLoopNum]() {
for (unsigned i = 0; i < kLoopNum; ++i) {
++counter;
}
});
// WaitQueueGroupEmpty()
queue_group->WaitQueueGroupEmpty();
EXPECT_EQ(counter.load(), kLoopNum * kExternalLoopNum + kLoopNum);
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename Holder>
class CounterGuard {
public:
explicit CounterGuard(Holder* holder) : counter_holder_(holder) {
assert(holder != nullptr);
counter_holder_->AddCounter();
}
~CounterGuard() {
if (counter_holder_ != nullptr) {
counter_holder_->SubCounter();
}
}
CounterGuard(CounterGuard&& other) : counter_holder_(other.counter_holder_) {
other.counter_holder_ = nullptr;
}
CounterGuard& operator=(CounterGuard&& other) {
counter_holder_ = other.counter_holder_;
other.counter_holder_ = nullptr;
return *this;
}
// copy constructor deleted, we define this for std::function
// never use it directly
CounterGuard(const CounterGuard& other) {
PADDLE_THROW(platform::errors::Unavailable(
"Never use the copy constructor of CounterGuard."));
}
CounterGuard& operator=(const CounterGuard&) = delete;
private:
Holder* counter_holder_{nullptr};
};
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册