未验证 提交 78c884d7 编写于 作者: A Abhinav Arora 提交者: GitHub

Redesign channel implementation for Select Op (#8814)

* Redesign channel implementation for Select Op

* Remove unecessary header

* Remove unnecessary comments
上级 351795e8
...@@ -28,24 +28,19 @@ class Channel { ...@@ -28,24 +28,19 @@ class Channel {
virtual bool Send(T*) = 0; virtual bool Send(T*) = 0;
virtual bool Receive(T*) = 0; virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0; virtual size_t Cap() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
virtual void Close() = 0; virtual void Close() = 0;
virtual ~Channel() {} virtual ~Channel() {}
}; };
// Forward declaration of channel implementations. // Forward declaration of channel implementations.
namespace details {
template <typename T> template <typename T>
class Buffered; class ChannelImpl;
template <typename T>
class UnBuffered;
} // namespace details
template <typename T> template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) { Channel<T>* MakeChannel(size_t buffer_size) {
if (buffer_size > 0) { return new ChannelImpl<T>(buffer_size);
return new details::Buffered<T>(buffer_size);
}
return new details::UnBuffered<T>();
} }
template <typename T> template <typename T>
...@@ -89,6 +84,19 @@ class ChannelHolder { ...@@ -89,6 +84,19 @@ class ChannelHolder {
if (IsInitialized()) holder_->Close(); if (IsInitialized()) holder_->Close();
} }
size_t Cap() {
if (IsInitialized()) return holder_->Cap();
return -1;
}
void Lock() {
if (IsInitialized()) holder_->Lock();
}
void Unlock() {
if (IsInitialized()) holder_->Unlock();
}
inline bool IsInitialized() const { return holder_ != nullptr; } inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() { inline const std::type_index Type() {
...@@ -106,6 +114,9 @@ class ChannelHolder { ...@@ -106,6 +114,9 @@ class ChannelHolder {
virtual const std::type_index Type() const = 0; virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0; virtual void* Ptr() const = 0;
virtual void Close() = 0; virtual void Close() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
virtual size_t Cap() = 0;
}; };
template <typename T> template <typename T>
...@@ -115,11 +126,28 @@ class ChannelHolder { ...@@ -115,11 +126,28 @@ class ChannelHolder {
} }
virtual const std::type_index Type() const { return type_; } virtual const std::type_index Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); } virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual void Close() { virtual void Close() {
if (channel_) channel_->Close(); if (channel_) channel_->Close();
} }
virtual size_t Cap() {
if (channel_)
return channel_->Cap();
else
return -1;
}
virtual void Lock() {
if (channel_) channel_->Lock();
}
virtual void Unlock() {
if (channel_) channel_->Unlock();
}
std::unique_ptr<Channel<T>> channel_; std::unique_ptr<Channel<T>> channel_;
const std::type_index type_; const std::type_index type_;
}; };
...@@ -131,5 +159,4 @@ class ChannelHolder { ...@@ -131,5 +159,4 @@ class ChannelHolder {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include "paddle/fluid/framework/details/buffered_channel.h" #include "paddle/fluid/framework/channel_impl.h"
#include "paddle/fluid/framework/details/unbuffered_channel.h"
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <atomic>
#include <condition_variable>
#include <deque>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T>
class ChannelImpl : public paddle::framework::Channel<T> {
friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T> *);
public:
virtual bool Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
virtual void Unlock();
virtual void Close();
ChannelImpl(size_t);
virtual ~ChannelImpl();
private:
struct QueueMessage {
T *data;
std::condition_variable_any cond;
bool chan_closed = false;
bool completed = false;
QueueMessage(T *item) : data(item) {}
void Wait(std::unique_lock<std::recursive_mutex> &lock) {
cond.wait(lock, [this]() { return completed; });
}
void Notify() {
completed = true;
cond.notify_all();
}
};
bool send_return(bool value) {
send_ctr--;
destructor_cond_.notify_all();
return value;
}
bool recv_return(bool value) {
recv_ctr--;
destructor_cond_.notify_all();
return value;
}
size_t cap_;
std::recursive_mutex mu_;
bool closed_;
std::deque<T> buf_;
std::deque<std::shared_ptr<QueueMessage>> recvq;
std::deque<std::shared_ptr<QueueMessage>> sendq;
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
std::condition_variable_any destructor_cond_;
};
template <typename T>
ChannelImpl<T>::ChannelImpl(size_t capacity)
: cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {
PADDLE_ENFORCE_GE(capacity, 0);
}
template <typename T>
bool ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed, do nothing
if (closed_) {
lock.unlock();
// TODO(abhinavarora) Should panic on closed channel
return send_return(false);
}
// If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any
if (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
// Do the data transfer
*(m->data) = std::move(*item);
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return send_return(true);
}
// Unbuffered channel will always bypass this
// If buffered channel has space in buffer,
// write the element to the buffer.
if (buf_.size() < cap_) {
// Copy to buffer
buf_.push_back(std::move(*item));
// Release lock and return true
lock.unlock();
return send_return(true);
}
// Block on channel, because some receiver will complete
// the operation for us
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel
return send_return(!m->chan_closed);
}
template <typename T>
bool ChannelImpl<T>::Receive(T *item) {
recv_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed and buffer is empty or
// channel is unbuffered
if (closed_ && buf_.empty()) {
lock.unlock();
return recv_return(false);
}
// If there is a sender, directly receive the value we want
// from the sender, bypassing the channel buffer if any
if (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
// Do the data transfer
*item = std::move(*(m->data));
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return recv_return(true);
}
// If this is a buffered channel and there are items in buffer
if (buf_.size() > 0) {
// Directly read from buffer
*item = std::move(buf_.front());
buf_.pop_front();
// Release lock and return true
lock.unlock();
return recv_return(true);
}
// No sender available, block on this channel
// Some receiver will complete the option for us
auto m = std::make_shared<QueueMessage>(item);
recvq.push_back(m);
m->Wait(lock);
return recv_return(!m->chan_closed);
}
template <typename T>
void ChannelImpl<T>::Lock() {
mu_.lock();
}
template <typename T>
void ChannelImpl<T>::Unlock() {
mu_.unlock();
}
template <typename T>
void ChannelImpl<T>::Close() {
std::unique_lock<std::recursive_mutex> lock{mu_};
if (closed_) {
// TODO(abhinavarora): closing an already closed channel should panic
lock.unlock();
return;
}
closed_ = true;
// Empty the readers
while (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
m->chan_closed = true;
m->Notify();
}
// Empty the senders
while (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
m->chan_closed = true;
m->Notify();
}
}
template <typename T>
ChannelImpl<T>::~ChannelImpl() {
Close();
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
std::unique_lock<std::recursive_mutex> lock{mu_};
destructor_cond_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
} // namespace framework
} // namespace paddle
...@@ -23,8 +23,19 @@ using paddle::framework::Channel; ...@@ -23,8 +23,19 @@ using paddle::framework::Channel;
using paddle::framework::ChannelHolder; using paddle::framework::ChannelHolder;
using paddle::framework::MakeChannel; using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel; using paddle::framework::CloseChannel;
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered; TEST(Channel, ChannelCapacityTest) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
EXPECT_EQ(ch->Cap(), buffer_size);
CloseChannel(ch);
delete ch;
ch = MakeChannel<size_t>(0);
EXPECT_EQ(ch->Cap(), 0U);
CloseChannel(ch);
delete ch;
}
void RecevingOrderEqualToSendingOrder(Channel<int> *ch) { void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
unsigned sum_send = 0; unsigned sum_send = 0;
...@@ -35,38 +46,17 @@ void RecevingOrderEqualToSendingOrder(Channel<int> *ch) { ...@@ -35,38 +46,17 @@ void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
} }
}); });
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
int recv; int recv = 999;
EXPECT_EQ(ch->Receive(&recv), true); EXPECT_EQ(ch->Receive(&recv), true);
EXPECT_EQ(recv, i); EXPECT_EQ(recv, i);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(200));
CloseChannel(ch); CloseChannel(ch);
t.join(); t.join();
EXPECT_EQ(sum_send, 10U); EXPECT_EQ(sum_send, 10U);
delete ch; delete ch;
} }
TEST(Channel, MakeAndClose) {
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
{
// MakeChannel should return a buffered channel is buffer_size > 0.
auto ch = MakeChannel<int>(10);
EXPECT_NE(dynamic_cast<Buffered<int> *>(ch), nullptr);
EXPECT_EQ(dynamic_cast<UnBuffered<int> *>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
{
// MakeChannel should return an un-buffered channel is buffer_size = 0.
auto ch = MakeChannel<int>(0);
EXPECT_EQ(dynamic_cast<Buffered<int> *>(ch), nullptr);
EXPECT_NE(dynamic_cast<UnBuffered<int> *>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
}
TEST(Channel, SufficientBufferSizeDoesntBlock) { TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10; const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size); auto ch = MakeChannel<size_t>(buffer_size);
...@@ -166,7 +156,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { ...@@ -166,7 +156,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
const size_t buffer_size = 10; const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size); auto ch = MakeChannel<size_t>(buffer_size);
size_t sum = 0;
std::thread t([&]() { std::thread t([&]() {
// Try to write more than buffer size. // Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) { for (size_t i = 0; i < 2 * buffer_size; ++i) {
...@@ -174,12 +163,9 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { ...@@ -174,12 +163,9 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations
else else
EXPECT_EQ(ch->Send(&i), false); EXPECT_EQ(ch->Send(&i), false);
sum += i;
} }
}); });
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
EXPECT_EQ(sum, 45U);
CloseChannel(ch); CloseChannel(ch);
t.join(); t.join();
delete ch; delete ch;
...@@ -211,7 +197,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) { ...@@ -211,7 +197,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
}, },
&thread_ended[i]); &thread_ended[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all the threads are blocked // Verify that all the threads are blocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -222,7 +208,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) { ...@@ -222,7 +208,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
// This should unblock all receivers // This should unblock all receivers
CloseChannel(ch); CloseChannel(ch);
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all threads got unblocked // Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -232,10 +218,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) { ...@@ -232,10 +218,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
for (size_t i = 0; i < num_threads; i++) t[i].join(); for (size_t i = 0; i < num_threads; i++) t[i].join();
} }
void ChannelCloseUnblocksSendersTest(Channel<int> *ch) { void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) {
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
size_t num_threads = 5; size_t num_threads = 5;
std::thread t[num_threads]; std::thread t[num_threads];
bool thread_ended[num_threads]; bool thread_ended[num_threads];
...@@ -253,9 +236,9 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch) { ...@@ -253,9 +236,9 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch) {
}, },
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
if (dynamic_cast<Buffered<int> *>(ch)) { if (isBuffered) {
// If ch is Buffered, atleast 4 threads must be blocked. // If ch is Buffered, atleast 4 threads must be blocked.
int ct = 0; int ct = 0;
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -272,14 +255,14 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch) { ...@@ -272,14 +255,14 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch) {
// This should unblock all senders // This should unblock all senders
CloseChannel(ch); CloseChannel(ch);
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads got unblocked // Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], true); EXPECT_EQ(thread_ended[i], true);
} }
if (dynamic_cast<Buffered<int> *>(ch)) { if (isBuffered) {
// Verify that only 1 send was successful // Verify that only 1 send was successful
int ct = 0; int ct = 0;
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -304,7 +287,7 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) { ...@@ -304,7 +287,7 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
// any senders waiting for channel to have write space // any senders waiting for channel to have write space
TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
auto ch = MakeChannel<int>(1); auto ch = MakeChannel<int>(1);
ChannelCloseUnblocksSendersTest(ch); ChannelCloseUnblocksSendersTest(ch, true);
delete ch; delete ch;
} }
...@@ -320,7 +303,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { ...@@ -320,7 +303,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) {
// unblocks any senders waiting for senders // unblocks any senders waiting for senders
TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) { TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) {
auto ch = MakeChannel<int>(0); auto ch = MakeChannel<int>(0);
ChannelCloseUnblocksReceiversTest(ch); ChannelCloseUnblocksSendersTest(ch, false);
delete ch; delete ch;
} }
...@@ -342,7 +325,7 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) { ...@@ -342,7 +325,7 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) {
ch->Receive(&recv); ch->Receive(&recv);
EXPECT_EQ(recv, i); EXPECT_EQ(recv, i);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
EXPECT_EQ(sum_send, 3U); EXPECT_EQ(sum_send, 3U);
CloseChannel(ch); CloseChannel(ch);
...@@ -368,7 +351,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) { ...@@ -368,7 +351,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
ch->Send(&i); ch->Send(&i);
sum_send += i; sum_send += i;
} }
std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
EXPECT_EQ(sum_send, 10U); EXPECT_EQ(sum_send, 10U);
EXPECT_EQ(sum_receive, 10U); EXPECT_EQ(sum_receive, 10U);
// send three more elements // send three more elements
...@@ -386,7 +369,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) { ...@@ -386,7 +369,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
// This tests that destroying a channel unblocks // This tests that destroying a channel unblocks
// any senders waiting for channel to have write space // any senders waiting for channel to have write space
void ChannelDestroyUnblockSenders(Channel<int> *ch) { void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) {
size_t num_threads = 5; size_t num_threads = 5;
std::thread t[num_threads]; std::thread t[num_threads];
bool thread_ended[num_threads]; bool thread_ended[num_threads];
...@@ -405,11 +388,9 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch) { ...@@ -405,11 +388,9 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch) {
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
bool is_buffered_channel = false;
if (dynamic_cast<Buffered<int> *>(ch)) is_buffered_channel = true;
if (is_buffered_channel) { if (isBuffered) {
// If channel is buffered, verify that atleast 4 threads are blocked // If channel is buffered, verify that atleast 4 threads are blocked
int ct = 0; int ct = 0;
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -432,13 +413,13 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch) { ...@@ -432,13 +413,13 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch) {
EXPECT_EQ(thread_ended[i], true); EXPECT_EQ(thread_ended[i], true);
} }
// Count number of successfuld sends // Count number of successful sends
int ct = 0; int ct = 0;
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
if (send_success[i]) ct++; if (send_success[i]) ct++;
} }
if (is_buffered_channel) { if (isBuffered) {
// Only 1 send must be successful // Only 1 send must be successful
EXPECT_EQ(ct, 1); EXPECT_EQ(ct, 1);
} else { } else {
...@@ -495,7 +476,7 @@ TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) { ...@@ -495,7 +476,7 @@ TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) {
TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) { TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) {
size_t buffer_size = 1; size_t buffer_size = 1;
auto ch = MakeChannel<int>(buffer_size); auto ch = MakeChannel<int>(buffer_size);
ChannelDestroyUnblockSenders(ch); ChannelDestroyUnblockSenders(ch, true);
} }
// This tests that destroying an unbuffered channel also unblocks // This tests that destroying an unbuffered channel also unblocks
...@@ -507,7 +488,20 @@ TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) { ...@@ -507,7 +488,20 @@ TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) {
TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) { TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
auto ch = MakeChannel<int>(0); auto ch = MakeChannel<int>(0);
ChannelDestroyUnblockSenders(ch); ChannelDestroyUnblockSenders(ch, false);
}
TEST(ChannelHolder, ChannelHolderCapacityTest) {
const size_t buffer_size = 10;
ChannelHolder *ch = new ChannelHolder();
ch->Reset<int>(buffer_size);
EXPECT_EQ(ch->Cap(), buffer_size);
delete ch;
ch = new ChannelHolder();
ch->Reset<int>(0);
EXPECT_EQ(ch->Cap(), 0U);
delete ch;
} }
void ChannelHolderSendReceive(ChannelHolder *ch) { void ChannelHolderSendReceive(ChannelHolder *ch) {
...@@ -641,7 +635,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { ...@@ -641,7 +635,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
}, },
&thread_ended[i]); &thread_ended[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all the threads are blocked // Verify that all the threads are blocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -652,7 +646,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { ...@@ -652,7 +646,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
// This should unblock all receivers // This should unblock all receivers
ch->close(); ch->close();
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all threads got unblocked // Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -663,9 +657,6 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { ...@@ -663,9 +657,6 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
} }
void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
size_t num_threads = 5; size_t num_threads = 5;
std::thread t[num_threads]; std::thread t[num_threads];
bool thread_ended[num_threads]; bool thread_ended[num_threads];
...@@ -683,7 +674,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { ...@@ -683,7 +674,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
}, },
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
if (isBuffered) { if (isBuffered) {
// If ch is Buffered, atleast 4 threads must be blocked. // If ch is Buffered, atleast 4 threads must be blocked.
...@@ -702,7 +693,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { ...@@ -702,7 +693,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
// This should unblock all senders // This should unblock all senders
ch->close(); ch->close();
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads got unblocked // Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
...@@ -775,7 +766,7 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) { ...@@ -775,7 +766,7 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
&thread_ended[i], &send_success[i]); &thread_ended[i], &send_success[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
if (isBuffered) { if (isBuffered) {
// If channel is buffered, verify that atleast 4 threads are blocked // If channel is buffered, verify that atleast 4 threads are blocked
int ct = 0; int ct = 0;
...@@ -836,7 +827,7 @@ void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) { ...@@ -836,7 +827,7 @@ void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) {
}, },
&thread_ended[i]); &thread_ended[i]);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads are blocked // Verify that all threads are blocked
for (size_t i = 0; i < num_threads; i++) { for (size_t i = 0; i < num_threads; i++) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace details {
// Four of the properties of Buffered Channel:
// - A send to a full channel blocks temporarily until a receive from the
// channel or the channel is closed.
// - A receive from an empty channel blocks temporarily until a send to the
// channel or the channel is closed.
// - A send to a closed channel returns false immediately.
// - A receive from a closed channel returns false immediately.
template <typename T>
class Buffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();
private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::condition_variable destructor_cond_var_;
std::deque<T> channel_;
std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
}
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
template <typename T>
bool Buffered<T>::Send(T* item) {
bool ret = false;
if (closed_) {
return ret;
}
send_ctr++;
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
ret = true;
}
send_ctr--;
destructor_cond_var_.notify_one();
return ret;
}
template <typename T>
bool Buffered<T>::Receive(T* item) {
bool ret = false;
// Once the channel has been closed and all data has been consumed,
// just return false. Don't even try acquiring the mutex.
if (closed_ && channel_.empty()) {
return false;
}
recv_ctr++;
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
if (!channel_.empty()) {
*item = std::move(channel_.front());
channel_.pop_front();
full_cond_var_.notify_one();
ret = true;
}
recv_ctr--;
destructor_cond_var_.notify_one();
return ret;
}
template <typename T>
void Buffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllParticipants(&lock);
}
template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
channel_.clear();
NotifyAllParticipants(&lock);
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
lock.lock();
destructor_cond_var_.wait(
lock, [this]() { return send_ctr == 0 && recv_ctr == 0; });
}
template <typename T>
void Buffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_all();
empty_cond_var_.notify_all();
}
} // namespace details
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "paddle/fluid/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
// Four of the properties of UnBuffered Channel:
// - A send to a channel blocks temporarily until a receive from the
// channel or the channel is closed.
// - A receive from a channel blocks temporarily until a send to the
// channel or the channel is closed.
// - A send to a closed channel returns false immediately.
// - A receive from a closed channel returns false immediately.
template <typename T>
class UnBuffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();
private:
std::mutex mu_ch_;
// Mutex for readers and writers who are waiting for other reader
// and writer to complete execution
std::recursive_mutex mu_read_, mu_write_;
// reader_found_ is set true when a reader is ready to accept data
// writer_found_ is set true when a writer is ready to send data
// A transaction occurs only when both are true
std::atomic<bool> reader_found_{false}, writer_found_{false};
std::condition_variable cv_channel_;
std::condition_variable_any cv_reader_, cv_writer_, cv_destructor_;
T* item{nullptr};
std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
UnBuffered() : closed_(false) {}
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
// This function implements the concept of how data should
// be sent from a writer to a reader.
template <typename T>
bool UnBuffered<T>::Send(T* data) {
bool ret = false;
if (closed_) {
return ret;
}
send_ctr++;
// Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true;
std::unique_lock<std::recursive_mutex> cv_lock(mu_write_);
// If writer comes first, it should wait till a reader arrives
cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one();
if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data;
channel_lock.unlock();
cv_channel_.notify_one();
channel_lock.lock();
cv_channel_.wait(channel_lock,
[this]() { return item == nullptr || closed_; });
ret = true;
}
writer_found_ = false;
send_ctr--;
cv_destructor_.notify_one();
return ret;
}
// This function implements the concept of how
// data that was sent by a writer is read from a reader.
template <typename T>
bool UnBuffered<T>::Receive(T* data) {
bool ret = false;
// If channel is closed, we don't even want any reader to enter.
// Unlike a buffered channel, an unbuffered channel does not allow
// readers to read after closing because there is no buffer to be consumed.
if (closed_) return ret;
recv_ctr++;
// Prevent other readers from entering
std::unique_lock<std::recursive_mutex> read_lock{mu_read_};
reader_found_ = true;
std::unique_lock<std::recursive_mutex> cv_lock{mu_read_};
// If reader comes first, it should wait till a writer arrives
cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one();
if (!closed_) {
std::unique_lock<std::mutex> lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data
cv_channel_.wait(lock_ch, [this]() { return item != nullptr || closed_; });
if (!closed_) {
*data = std::move(*item);
item = nullptr;
lock_ch.unlock();
ret = true;
}
cv_channel_.notify_one();
}
reader_found_ = false;
recv_ctr--;
cv_destructor_.notify_one();
return ret;
}
// This function implements the sequence of events
// that take place once the channel is closed.
template <typename T>
void UnBuffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
}
// This function implements the sequence of events
// that are executed once the object of an UnBuffered
// channel is destroyed.
template <typename T>
UnBuffered<T>::~UnBuffered() {
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
lock.lock();
cv_destructor_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
// This function notifies all the readers, writers and
// the channel condition variables.
template <typename T>
void UnBuffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
cv_writer_.notify_all();
cv_channel_.notify_all();
cv_reader_.notify_all();
}
} // namespace details
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册