From 78c884d7a70205a894ca7f446bdae2ace87f24e1 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Tue, 6 Mar 2018 12:10:40 -0800 Subject: [PATCH] Redesign channel implementation for Select Op (#8814) * Redesign channel implementation for Select Op * Remove unecessary header * Remove unnecessary comments --- paddle/fluid/framework/channel.h | 49 +++- paddle/fluid/framework/channel_impl.h | 229 ++++++++++++++++++ paddle/fluid/framework/channel_test.cc | 115 ++++----- .../framework/details/buffered_channel.h | 142 ----------- .../framework/details/unbuffered_channel.h | 174 ------------- 5 files changed, 320 insertions(+), 389 deletions(-) create mode 100644 paddle/fluid/framework/channel_impl.h delete mode 100644 paddle/fluid/framework/details/buffered_channel.h delete mode 100644 paddle/fluid/framework/details/unbuffered_channel.h diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index bda1bfb23b1..9f8fb12098d 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -28,24 +28,19 @@ class Channel { virtual bool Send(T*) = 0; virtual bool Receive(T*) = 0; virtual size_t Cap() = 0; + virtual void Lock() = 0; + virtual void Unlock() = 0; virtual void Close() = 0; virtual ~Channel() {} }; // Forward declaration of channel implementations. -namespace details { template -class Buffered; -template -class UnBuffered; -} // namespace details +class ChannelImpl; template Channel* MakeChannel(size_t buffer_size) { - if (buffer_size > 0) { - return new details::Buffered(buffer_size); - } - return new details::UnBuffered(); + return new ChannelImpl(buffer_size); } template @@ -89,6 +84,19 @@ class ChannelHolder { if (IsInitialized()) holder_->Close(); } + size_t Cap() { + if (IsInitialized()) return holder_->Cap(); + return -1; + } + + void Lock() { + if (IsInitialized()) holder_->Lock(); + } + + void Unlock() { + if (IsInitialized()) holder_->Unlock(); + } + inline bool IsInitialized() const { return holder_ != nullptr; } inline const std::type_index Type() { @@ -106,6 +114,9 @@ class ChannelHolder { virtual const std::type_index Type() const = 0; virtual void* Ptr() const = 0; virtual void Close() = 0; + virtual void Lock() = 0; + virtual void Unlock() = 0; + virtual size_t Cap() = 0; }; template @@ -115,11 +126,28 @@ class ChannelHolder { } virtual const std::type_index Type() const { return type_; } + virtual void* Ptr() const { return static_cast(channel_.get()); } + virtual void Close() { if (channel_) channel_->Close(); } + virtual size_t Cap() { + if (channel_) + return channel_->Cap(); + else + return -1; + } + + virtual void Lock() { + if (channel_) channel_->Lock(); + } + + virtual void Unlock() { + if (channel_) channel_->Unlock(); + } + std::unique_ptr> channel_; const std::type_index type_; }; @@ -131,5 +159,4 @@ class ChannelHolder { } // namespace framework } // namespace paddle -#include "paddle/fluid/framework/details/buffered_channel.h" -#include "paddle/fluid/framework/details/unbuffered_channel.h" +#include "paddle/fluid/framework/channel_impl.h" diff --git a/paddle/fluid/framework/channel_impl.h b/paddle/fluid/framework/channel_impl.h new file mode 100644 index 00000000000..a4561031fd8 --- /dev/null +++ b/paddle/fluid/framework/channel_impl.h @@ -0,0 +1,229 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include // for size_t +#include +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { + +template +class ChannelImpl : public paddle::framework::Channel { + friend Channel *paddle::framework::MakeChannel(size_t); + friend void paddle::framework::CloseChannel(Channel *); + + public: + virtual bool Send(T *); + virtual bool Receive(T *); + virtual size_t Cap() { return cap_; } + virtual void Lock(); + virtual void Unlock(); + virtual void Close(); + + ChannelImpl(size_t); + virtual ~ChannelImpl(); + + private: + struct QueueMessage { + T *data; + std::condition_variable_any cond; + bool chan_closed = false; + bool completed = false; + + QueueMessage(T *item) : data(item) {} + + void Wait(std::unique_lock &lock) { + cond.wait(lock, [this]() { return completed; }); + } + + void Notify() { + completed = true; + cond.notify_all(); + } + }; + + bool send_return(bool value) { + send_ctr--; + destructor_cond_.notify_all(); + return value; + } + + bool recv_return(bool value) { + recv_ctr--; + destructor_cond_.notify_all(); + return value; + } + + size_t cap_; + std::recursive_mutex mu_; + bool closed_; + std::deque buf_; + std::deque> recvq; + std::deque> sendq; + std::atomic send_ctr{0}; + std::atomic recv_ctr{0}; + std::condition_variable_any destructor_cond_; +}; + +template +ChannelImpl::ChannelImpl(size_t capacity) + : cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) { + PADDLE_ENFORCE_GE(capacity, 0); +} + +template +bool ChannelImpl::Send(T *item) { + send_ctr++; + std::unique_lock lock{mu_}; + + // If channel is closed, do nothing + if (closed_) { + lock.unlock(); + // TODO(abhinavarora) Should panic on closed channel + return send_return(false); + } + + // If there is a receiver, directly pass the value we want + // to send to the receiver, bypassing the channel buffer if any + if (!recvq.empty()) { + std::shared_ptr m = recvq.front(); + recvq.pop_front(); + // Do the data transfer + *(m->data) = std::move(*item); + // Wake up the blocked process and unlock + m->Notify(); + lock.unlock(); + return send_return(true); + } + + // Unbuffered channel will always bypass this + // If buffered channel has space in buffer, + // write the element to the buffer. + if (buf_.size() < cap_) { + // Copy to buffer + buf_.push_back(std::move(*item)); + // Release lock and return true + lock.unlock(); + return send_return(true); + } + + // Block on channel, because some receiver will complete + // the operation for us + auto m = std::make_shared(item); + sendq.push_back(m); + m->Wait(lock); + // TODO(abhinavarora) Should panic on closed channel + return send_return(!m->chan_closed); +} + +template +bool ChannelImpl::Receive(T *item) { + recv_ctr++; + std::unique_lock lock{mu_}; + + // If channel is closed and buffer is empty or + // channel is unbuffered + if (closed_ && buf_.empty()) { + lock.unlock(); + return recv_return(false); + } + + // If there is a sender, directly receive the value we want + // from the sender, bypassing the channel buffer if any + if (!sendq.empty()) { + std::shared_ptr m = sendq.front(); + sendq.pop_front(); + // Do the data transfer + *item = std::move(*(m->data)); + // Wake up the blocked process and unlock + m->Notify(); + lock.unlock(); + return recv_return(true); + } + + // If this is a buffered channel and there are items in buffer + if (buf_.size() > 0) { + // Directly read from buffer + *item = std::move(buf_.front()); + buf_.pop_front(); + // Release lock and return true + lock.unlock(); + return recv_return(true); + } + + // No sender available, block on this channel + // Some receiver will complete the option for us + auto m = std::make_shared(item); + recvq.push_back(m); + m->Wait(lock); + + return recv_return(!m->chan_closed); +} + +template +void ChannelImpl::Lock() { + mu_.lock(); +} + +template +void ChannelImpl::Unlock() { + mu_.unlock(); +} + +template +void ChannelImpl::Close() { + std::unique_lock lock{mu_}; + + if (closed_) { + // TODO(abhinavarora): closing an already closed channel should panic + lock.unlock(); + return; + } + + closed_ = true; + + // Empty the readers + while (!recvq.empty()) { + std::shared_ptr m = recvq.front(); + recvq.pop_front(); + m->chan_closed = true; + m->Notify(); + } + + // Empty the senders + while (!sendq.empty()) { + std::shared_ptr m = sendq.front(); + sendq.pop_front(); + m->chan_closed = true; + m->Notify(); + } +} + +template +ChannelImpl::~ChannelImpl() { + Close(); + // The destructor must wait for all readers and writers to complete their task + // The channel has been closed, so we will not accept new readers and writers + std::unique_lock lock{mu_}; + destructor_cond_.wait(lock, + [this]() { return send_ctr == 0 && recv_ctr == 0; }); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index 695169fcb9e..edfb41c7248 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -23,8 +23,19 @@ using paddle::framework::Channel; using paddle::framework::ChannelHolder; using paddle::framework::MakeChannel; using paddle::framework::CloseChannel; -using paddle::framework::details::Buffered; -using paddle::framework::details::UnBuffered; + +TEST(Channel, ChannelCapacityTest) { + const size_t buffer_size = 10; + auto ch = MakeChannel(buffer_size); + EXPECT_EQ(ch->Cap(), buffer_size); + CloseChannel(ch); + delete ch; + + ch = MakeChannel(0); + EXPECT_EQ(ch->Cap(), 0U); + CloseChannel(ch); + delete ch; +} void RecevingOrderEqualToSendingOrder(Channel *ch) { unsigned sum_send = 0; @@ -35,38 +46,17 @@ void RecevingOrderEqualToSendingOrder(Channel *ch) { } }); for (int i = 0; i < 5; i++) { - int recv; + int recv = 999; EXPECT_EQ(ch->Receive(&recv), true); EXPECT_EQ(recv, i); } - + std::this_thread::sleep_for(std::chrono::milliseconds(200)); CloseChannel(ch); t.join(); EXPECT_EQ(sum_send, 10U); delete ch; } -TEST(Channel, MakeAndClose) { - using paddle::framework::details::Buffered; - using paddle::framework::details::UnBuffered; - { - // MakeChannel should return a buffered channel is buffer_size > 0. - auto ch = MakeChannel(10); - EXPECT_NE(dynamic_cast *>(ch), nullptr); - EXPECT_EQ(dynamic_cast *>(ch), nullptr); - CloseChannel(ch); - delete ch; - } - { - // MakeChannel should return an un-buffered channel is buffer_size = 0. - auto ch = MakeChannel(0); - EXPECT_EQ(dynamic_cast *>(ch), nullptr); - EXPECT_NE(dynamic_cast *>(ch), nullptr); - CloseChannel(ch); - delete ch; - } -} - TEST(Channel, SufficientBufferSizeDoesntBlock) { const size_t buffer_size = 10; auto ch = MakeChannel(buffer_size); @@ -166,7 +156,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { const size_t buffer_size = 10; auto ch = MakeChannel(buffer_size); - size_t sum = 0; std::thread t([&]() { // Try to write more than buffer size. for (size_t i = 0; i < 2 * buffer_size; ++i) { @@ -174,12 +163,9 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations else EXPECT_EQ(ch->Send(&i), false); - sum += i; } }); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec - EXPECT_EQ(sum, 45U); - + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec CloseChannel(ch); t.join(); delete ch; @@ -211,7 +197,7 @@ void ChannelCloseUnblocksReceiversTest(Channel *ch) { }, &thread_ended[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec // Verify that all the threads are blocked for (size_t i = 0; i < num_threads; i++) { @@ -222,7 +208,7 @@ void ChannelCloseUnblocksReceiversTest(Channel *ch) { // This should unblock all receivers CloseChannel(ch); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec // Verify that all threads got unblocked for (size_t i = 0; i < num_threads; i++) { @@ -232,10 +218,7 @@ void ChannelCloseUnblocksReceiversTest(Channel *ch) { for (size_t i = 0; i < num_threads; i++) t[i].join(); } -void ChannelCloseUnblocksSendersTest(Channel *ch) { - using paddle::framework::details::Buffered; - using paddle::framework::details::UnBuffered; - +void ChannelCloseUnblocksSendersTest(Channel *ch, bool isBuffered) { size_t num_threads = 5; std::thread t[num_threads]; bool thread_ended[num_threads]; @@ -253,9 +236,9 @@ void ChannelCloseUnblocksSendersTest(Channel *ch) { }, &thread_ended[i], &send_success[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait - if (dynamic_cast *>(ch)) { + if (isBuffered) { // If ch is Buffered, atleast 4 threads must be blocked. int ct = 0; for (size_t i = 0; i < num_threads; i++) { @@ -272,14 +255,14 @@ void ChannelCloseUnblocksSendersTest(Channel *ch) { // This should unblock all senders CloseChannel(ch); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait // Verify that all threads got unblocked for (size_t i = 0; i < num_threads; i++) { EXPECT_EQ(thread_ended[i], true); } - if (dynamic_cast *>(ch)) { + if (isBuffered) { // Verify that only 1 send was successful int ct = 0; for (size_t i = 0; i < num_threads; i++) { @@ -304,7 +287,7 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) { // any senders waiting for channel to have write space TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { auto ch = MakeChannel(1); - ChannelCloseUnblocksSendersTest(ch); + ChannelCloseUnblocksSendersTest(ch, true); delete ch; } @@ -320,7 +303,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { // unblocks any senders waiting for senders TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) { auto ch = MakeChannel(0); - ChannelCloseUnblocksReceiversTest(ch); + ChannelCloseUnblocksSendersTest(ch, false); delete ch; } @@ -342,7 +325,7 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) { ch->Receive(&recv); EXPECT_EQ(recv, i); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec EXPECT_EQ(sum_send, 3U); CloseChannel(ch); @@ -368,7 +351,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) { ch->Send(&i); sum_send += i; } - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec EXPECT_EQ(sum_send, 10U); EXPECT_EQ(sum_receive, 10U); // send three more elements @@ -386,7 +369,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) { // This tests that destroying a channel unblocks // any senders waiting for channel to have write space -void ChannelDestroyUnblockSenders(Channel *ch) { +void ChannelDestroyUnblockSenders(Channel *ch, bool isBuffered) { size_t num_threads = 5; std::thread t[num_threads]; bool thread_ended[num_threads]; @@ -405,11 +388,9 @@ void ChannelDestroyUnblockSenders(Channel *ch) { &thread_ended[i], &send_success[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec - bool is_buffered_channel = false; - if (dynamic_cast *>(ch)) is_buffered_channel = true; + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec - if (is_buffered_channel) { + if (isBuffered) { // If channel is buffered, verify that atleast 4 threads are blocked int ct = 0; for (size_t i = 0; i < num_threads; i++) { @@ -432,13 +413,13 @@ void ChannelDestroyUnblockSenders(Channel *ch) { EXPECT_EQ(thread_ended[i], true); } - // Count number of successfuld sends + // Count number of successful sends int ct = 0; for (size_t i = 0; i < num_threads; i++) { if (send_success[i]) ct++; } - if (is_buffered_channel) { + if (isBuffered) { // Only 1 send must be successful EXPECT_EQ(ct, 1); } else { @@ -495,7 +476,7 @@ TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) { TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) { size_t buffer_size = 1; auto ch = MakeChannel(buffer_size); - ChannelDestroyUnblockSenders(ch); + ChannelDestroyUnblockSenders(ch, true); } // This tests that destroying an unbuffered channel also unblocks @@ -507,7 +488,20 @@ TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) { TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) { auto ch = MakeChannel(0); - ChannelDestroyUnblockSenders(ch); + ChannelDestroyUnblockSenders(ch, false); +} + +TEST(ChannelHolder, ChannelHolderCapacityTest) { + const size_t buffer_size = 10; + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(buffer_size); + EXPECT_EQ(ch->Cap(), buffer_size); + delete ch; + + ch = new ChannelHolder(); + ch->Reset(0); + EXPECT_EQ(ch->Cap(), 0U); + delete ch; } void ChannelHolderSendReceive(ChannelHolder *ch) { @@ -641,7 +635,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { }, &thread_ended[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec // Verify that all the threads are blocked for (size_t i = 0; i < num_threads; i++) { @@ -652,7 +646,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { // This should unblock all receivers ch->close(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec // Verify that all threads got unblocked for (size_t i = 0; i < num_threads; i++) { @@ -663,9 +657,6 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { } void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { - using paddle::framework::details::Buffered; - using paddle::framework::details::UnBuffered; - size_t num_threads = 5; std::thread t[num_threads]; bool thread_ended[num_threads]; @@ -683,7 +674,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { }, &thread_ended[i], &send_success[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait if (isBuffered) { // If ch is Buffered, atleast 4 threads must be blocked. @@ -702,7 +693,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { // This should unblock all senders ch->close(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait // Verify that all threads got unblocked for (size_t i = 0; i < num_threads; i++) { @@ -775,7 +766,7 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) { &thread_ended[i], &send_success[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec if (isBuffered) { // If channel is buffered, verify that atleast 4 threads are blocked int ct = 0; @@ -836,7 +827,7 @@ void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) { }, &thread_ended[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait // Verify that all threads are blocked for (size_t i = 0; i < num_threads; i++) { diff --git a/paddle/fluid/framework/details/buffered_channel.h b/paddle/fluid/framework/details/buffered_channel.h deleted file mode 100644 index 88faf3acf7c..00000000000 --- a/paddle/fluid/framework/details/buffered_channel.h +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include - -#include "paddle/fluid/framework/channel.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace framework { -namespace details { - -// Four of the properties of Buffered Channel: -// - A send to a full channel blocks temporarily until a receive from the -// channel or the channel is closed. -// - A receive from an empty channel blocks temporarily until a send to the -// channel or the channel is closed. -// - A send to a closed channel returns false immediately. -// - A receive from a closed channel returns false immediately. - -template -class Buffered : public paddle::framework::Channel { - friend Channel* paddle::framework::MakeChannel(size_t); - friend void paddle::framework::CloseChannel(Channel*); - - public: - virtual bool Send(T*); - virtual bool Receive(T*); - virtual size_t Cap() { return cap_; } - virtual void Close(); - virtual ~Buffered(); - - private: - size_t cap_; - std::mutex mu_; - std::condition_variable empty_cond_var_; - std::condition_variable full_cond_var_; - std::condition_variable destructor_cond_var_; - std::deque channel_; - std::atomic closed_{false}; - std::atomic send_ctr{0}; - std::atomic recv_ctr{0}; - - Buffered(size_t cap) : cap_(cap), closed_(false) { - PADDLE_ENFORCE_GT(cap, 0); - } - - void NotifyAllParticipants(std::unique_lock*); -}; - -template -bool Buffered::Send(T* item) { - bool ret = false; - if (closed_) { - return ret; - } - send_ctr++; - std::unique_lock lock(mu_); - full_cond_var_.wait(lock, - [this]() { return channel_.size() < cap_ || closed_; }); - if (!closed_) { - channel_.push_back(std::move(*item)); - lock.unlock(); - empty_cond_var_.notify_one(); - ret = true; - } - send_ctr--; - destructor_cond_var_.notify_one(); - return ret; -} - -template -bool Buffered::Receive(T* item) { - bool ret = false; - // Once the channel has been closed and all data has been consumed, - // just return false. Don't even try acquiring the mutex. - if (closed_ && channel_.empty()) { - return false; - } - recv_ctr++; - std::unique_lock lock(mu_); - empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; }); - if (!channel_.empty()) { - *item = std::move(channel_.front()); - channel_.pop_front(); - full_cond_var_.notify_one(); - ret = true; - } - recv_ctr--; - destructor_cond_var_.notify_one(); - return ret; -} - -template -void Buffered::Close() { - if (closed_) { - return; - } - std::unique_lock lock(mu_); - closed_ = true; - NotifyAllParticipants(&lock); -} - -template -Buffered::~Buffered() { - std::unique_lock lock(mu_); - closed_ = true; - channel_.clear(); - NotifyAllParticipants(&lock); - - // The destructor must wait for all readers and writers to complete their task - // The channel has been closed, so we will not accept new readers and writers - lock.lock(); - destructor_cond_var_.wait( - lock, [this]() { return send_ctr == 0 && recv_ctr == 0; }); -} - -template -void Buffered::NotifyAllParticipants(std::unique_lock* lock) { - lock->unlock(); - full_cond_var_.notify_all(); - empty_cond_var_.notify_all(); -} - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/unbuffered_channel.h b/paddle/fluid/framework/details/unbuffered_channel.h deleted file mode 100644 index 5c9424928cb..00000000000 --- a/paddle/fluid/framework/details/unbuffered_channel.h +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include - -#include "paddle/fluid/framework/channel.h" - -namespace paddle { -namespace framework { -namespace details { - -// Four of the properties of UnBuffered Channel: -// - A send to a channel blocks temporarily until a receive from the -// channel or the channel is closed. -// - A receive from a channel blocks temporarily until a send to the -// channel or the channel is closed. -// - A send to a closed channel returns false immediately. -// - A receive from a closed channel returns false immediately. -template -class UnBuffered : public paddle::framework::Channel { - friend Channel* paddle::framework::MakeChannel(size_t); - friend void paddle::framework::CloseChannel(Channel*); - - public: - virtual bool Send(T*); - virtual bool Receive(T*); - virtual size_t Cap() { return 0; } - virtual void Close(); - virtual ~UnBuffered(); - - private: - std::mutex mu_ch_; - // Mutex for readers and writers who are waiting for other reader - // and writer to complete execution - std::recursive_mutex mu_read_, mu_write_; - // reader_found_ is set true when a reader is ready to accept data - // writer_found_ is set true when a writer is ready to send data - // A transaction occurs only when both are true - std::atomic reader_found_{false}, writer_found_{false}; - std::condition_variable cv_channel_; - std::condition_variable_any cv_reader_, cv_writer_, cv_destructor_; - T* item{nullptr}; - std::atomic closed_{false}; - std::atomic send_ctr{0}; - std::atomic recv_ctr{0}; - - UnBuffered() : closed_(false) {} - - void NotifyAllParticipants(std::unique_lock*); -}; - -// This function implements the concept of how data should -// be sent from a writer to a reader. -template -bool UnBuffered::Send(T* data) { - bool ret = false; - if (closed_) { - return ret; - } - send_ctr++; - // Prevent other writers from entering - std::unique_lock writer_lock(mu_write_); - writer_found_ = true; - std::unique_lock cv_lock(mu_write_); - // If writer comes first, it should wait till a reader arrives - cv_writer_.wait(cv_lock, - [this]() { return reader_found_ == true || closed_; }); - cv_reader_.notify_one(); - if (!closed_) { - std::unique_lock channel_lock(mu_ch_); - item = data; - channel_lock.unlock(); - cv_channel_.notify_one(); - channel_lock.lock(); - cv_channel_.wait(channel_lock, - [this]() { return item == nullptr || closed_; }); - ret = true; - } - writer_found_ = false; - send_ctr--; - cv_destructor_.notify_one(); - return ret; -} - -// This function implements the concept of how -// data that was sent by a writer is read from a reader. -template -bool UnBuffered::Receive(T* data) { - bool ret = false; - // If channel is closed, we don't even want any reader to enter. - // Unlike a buffered channel, an unbuffered channel does not allow - // readers to read after closing because there is no buffer to be consumed. - if (closed_) return ret; - recv_ctr++; - // Prevent other readers from entering - std::unique_lock read_lock{mu_read_}; - reader_found_ = true; - std::unique_lock cv_lock{mu_read_}; - // If reader comes first, it should wait till a writer arrives - cv_reader_.wait(cv_lock, - [this]() { return writer_found_ == true || closed_; }); - cv_writer_.notify_one(); - if (!closed_) { - std::unique_lock lock_ch{mu_ch_}; - // Reader should wait for the writer to first write its data - cv_channel_.wait(lock_ch, [this]() { return item != nullptr || closed_; }); - if (!closed_) { - *data = std::move(*item); - item = nullptr; - lock_ch.unlock(); - ret = true; - } - cv_channel_.notify_one(); - } - reader_found_ = false; - recv_ctr--; - cv_destructor_.notify_one(); - return ret; -} - -// This function implements the sequence of events -// that take place once the channel is closed. -template -void UnBuffered::Close() { - if (closed_) { - return; - } - std::unique_lock lock(mu_ch_); - item = nullptr; - closed_ = true; - NotifyAllParticipants(&lock); -} - -// This function implements the sequence of events -// that are executed once the object of an UnBuffered -// channel is destroyed. -template -UnBuffered::~UnBuffered() { - std::unique_lock lock(mu_ch_); - item = nullptr; - closed_ = true; - NotifyAllParticipants(&lock); - lock.lock(); - cv_destructor_.wait(lock, - [this]() { return send_ctr == 0 && recv_ctr == 0; }); -} - -// This function notifies all the readers, writers and -// the channel condition variables. -template -void UnBuffered::NotifyAllParticipants(std::unique_lock* lock) { - lock->unlock(); - cv_writer_.notify_all(); - cv_channel_.notify_all(); - cv_reader_.notify_all(); -} - -} // namespace details -} // namespace framework -} // namespace paddle -- GitLab