From b60da6729fa2484506869bc29271761de91676b7 Mon Sep 17 00:00:00 2001 From: chengduo Date: Sat, 3 Feb 2018 23:32:56 +0800 Subject: [PATCH] Refine buffer channel (#8098) * refine buffer channel * refine Receive and Send * follow comments --- paddle/framework/channel.h | 4 +-- paddle/framework/details/buffered_channel.h | 25 ++++++++----------- paddle/framework/details/unbuffered_channel.h | 14 ++++++++--- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/paddle/framework/channel.h b/paddle/framework/channel.h index 0570980c5a4..b679387b112 100644 --- a/paddle/framework/channel.h +++ b/paddle/framework/channel.h @@ -23,8 +23,8 @@ namespace framework { template class Channel { public: - virtual void Send(T*) = 0; - virtual void Receive(T*) = 0; + virtual bool Send(T*) = 0; + virtual bool Receive(T*) = 0; virtual size_t Cap() = 0; virtual void Close() = 0; virtual ~Channel() {} diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h index 9c806461aa5..7ac234b8d42 100644 --- a/paddle/framework/details/buffered_channel.h +++ b/paddle/framework/details/buffered_channel.h @@ -30,8 +30,8 @@ class Buffered : public paddle::framework::Channel { friend void paddle::framework::CloseChannel(Channel*); public: - virtual void Send(T*); - virtual void Receive(T*); + virtual bool Send(T*); + virtual bool Receive(T*); virtual size_t Cap() { return cap_; } virtual void Close(); virtual ~Buffered(); @@ -48,33 +48,36 @@ class Buffered : public paddle::framework::Channel { PADDLE_ENFORCE_GT(cap, 0); } - void NotifyAllSenders(std::unique_lock*); void NotifyAllParticipants(std::unique_lock*); }; template -void Buffered::Send(T* item) { +bool Buffered::Send(T* item) { std::unique_lock lock(mu_); full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_ || closed_; }); + bool ret = false; if (!closed_) { channel_.push_back(std::move(*item)); lock.unlock(); empty_cond_var_.notify_one(); + ret = true; } + return ret; } template -void Buffered::Receive(T* item) { +bool Buffered::Receive(T* item) { std::unique_lock lock(mu_); empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; }); + bool ret = false; if (!closed_) { *item = std::move(channel_.front()); channel_.pop_front(); - NotifyAllSenders(&lock); - } else { - item = nullptr; + full_cond_var_.notify_one(); + ret = true; } + return ret; } template @@ -92,12 +95,6 @@ Buffered::~Buffered() { NotifyAllParticipants(&lock); } -template -void Buffered::NotifyAllSenders(std::unique_lock* lock) { - lock->unlock(); - full_cond_var_.notify_all(); -} - template void Buffered::NotifyAllParticipants(std::unique_lock* lock) { lock->unlock(); diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h index 0dc5afd7e57..f86a894bb4a 100644 --- a/paddle/framework/details/unbuffered_channel.h +++ b/paddle/framework/details/unbuffered_channel.h @@ -29,8 +29,8 @@ class UnBuffered : public paddle::framework::Channel { friend void paddle::framework::CloseChannel(Channel*); public: - virtual void Send(T*); - virtual void Receive(T*); + virtual bool Send(T*); + virtual bool Receive(T*); virtual size_t Cap() { return 0; } virtual void Close(); virtual ~UnBuffered(); @@ -57,7 +57,7 @@ class UnBuffered : public paddle::framework::Channel { // This function implements the concept of how data should // be sent from a writer to a reader. template -void UnBuffered::Send(T* data) { +bool UnBuffered::Send(T* data) { // Prevent other writers from entering std::unique_lock writer_lock(mu_write_); writer_found_ = true; @@ -66,6 +66,7 @@ void UnBuffered::Send(T* data) { cv_writer_.wait(cv_lock, [this]() { return reader_found_ == true || closed_; }); cv_reader_.notify_one(); + bool ret = false; if (!closed_) { std::unique_lock channel_lock(mu_ch_); item = data; @@ -74,14 +75,16 @@ void UnBuffered::Send(T* data) { channel_lock.lock(); cv_channel_.wait(channel_lock, [this]() { return item == nullptr || closed_; }); + ret = true; } writer_found_ = false; + return ret; } // This function implements the concept of how // data that was sent by a writer is read from a reader. template -void UnBuffered::Receive(T* data) { +bool UnBuffered::Receive(T* data) { // Prevent other readers from entering std::unique_lock read_lock{mu_read_}; reader_found_ = true; @@ -90,6 +93,7 @@ void UnBuffered::Receive(T* data) { cv_reader_.wait(cv_lock, [this]() { return writer_found_ == true || closed_; }); cv_writer_.notify_one(); + bool ret = false; if (!closed_) { std::unique_lock lock_ch{mu_ch_}; // Reader should wait for the writer to first write its data @@ -98,10 +102,12 @@ void UnBuffered::Receive(T* data) { *data = std::move(*item); item = nullptr; lock_ch.unlock(); + ret = true; } cv_channel_.notify_one(); } reader_found_ = false; + return ret; } // This function implements the sequence of events -- GitLab