未验证 提交 a1fc5701 编写于 作者: A Abhinav Arora 提交者: GitHub

Fix for program crash when destructor is called before channel close with...

Fix for program crash when destructor is called before channel close with blocked readers/writers (#8197)

* Fix destructor crash and add unit tests

* Fix typo in unit test

* Reword comments

* Make close channel a generic test

* Refactoring unit tests

* Fix method name
上级 b1869f16
...@@ -22,6 +22,8 @@ limitations under the License. */ ...@@ -22,6 +22,8 @@ limitations under the License. */
using paddle::framework::Channel; using paddle::framework::Channel;
using paddle::framework::MakeChannel; using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel; using paddle::framework::CloseChannel;
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
TEST(Channel, MakeAndClose) { TEST(Channel, MakeAndClose) {
using paddle::framework::details::Buffered; using paddle::framework::details::Buffered;
...@@ -60,13 +62,54 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { ...@@ -60,13 +62,54 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) {
delete ch; delete ch;
} }
TEST(Channel, SendOnClosedChannelPanics) { // This tests that a channel must return false
const size_t buffer_size = 10; // on send and receive performed after closing the channel.
auto ch = MakeChannel<size_t>(buffer_size); // Receive will only return false after close when queue is empty.
size_t i = 5; // By creating separate threads for sending and receiving, we make this
EXPECT_EQ(ch->Send(&i), true); // should not block or panic // function able to test both buffered and unbuffered channels.
void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
const size_t data = 5;
std::thread send_thread{[&]() {
size_t i = data;
EXPECT_EQ(ch->Send(&i), true); // should not block
}};
std::thread recv_thread{[&]() {
size_t i;
EXPECT_EQ(ch->Receive(&i), true); // should not block
EXPECT_EQ(i, data);
}};
send_thread.join();
recv_thread.join();
// After closing send should return false. Receive should
// also return false as there is no data in queue.
CloseChannel(ch); CloseChannel(ch);
EXPECT_EQ(ch->Send(&i), false); // should panic send_thread = std::thread{[&]() {
size_t i = data;
EXPECT_EQ(ch->Send(&i), false); // should return false
}};
recv_thread = std::thread{[&]() {
size_t i;
// should return false because channel is closed and queue is empty
EXPECT_EQ(ch->Receive(&i), false);
}};
send_thread.join();
recv_thread.join();
}
TEST(Channel, SendReceiveClosedBufferedChannelPanics) {
size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
SendReceiveWithACloseChannelShouldPanic(ch);
delete ch;
}
TEST(Channel, SendReceiveClosedUnBufferedChannelPanics) {
auto ch = MakeChannel<size_t>(0);
SendReceiveWithACloseChannelShouldPanic(ch);
delete ch; delete ch;
} }
...@@ -381,3 +424,129 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) { ...@@ -381,3 +424,129 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
EXPECT_EQ(sum_receive, 28U); EXPECT_EQ(sum_receive, 28U);
delete ch; delete ch;
} }
// This tests that destroying a channel unblocks
// any senders waiting for channel to have write space
void ChannelDestroyUnblockSenders(Channel<int> *ch) {
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
bool send_success[num_threads];
// Launches threads that try to write and are blocked because of no readers
for (size_t i = 0; i < num_threads; i++) {
thread_ended[i] = false;
send_success[i] = false;
t[i] = std::thread(
[&](bool *ended, bool *success) {
int data = 10;
*success = ch->Send(&data);
*ended = true;
},
&thread_ended[i], &send_success[i]);
}
std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
bool is_buffered_channel = false;
if (dynamic_cast<Buffered<int> *>(ch)) is_buffered_channel = true;
if (is_buffered_channel) {
// If channel is buffered, verify that atleast 4 threads are blocked
int ct = 0;
for (size_t i = 0; i < num_threads; i++) {
if (thread_ended[i] == false) ct++;
}
// Atleast 4 threads must be blocked
EXPECT_GE(ct, 4);
} else {
// Verify that all the threads are blocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], false);
}
}
// Explicitly destroy the channel
delete ch;
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], true);
}
// Count number of successfuld sends
int ct = 0;
for (size_t i = 0; i < num_threads; i++) {
if (send_success[i]) ct++;
}
if (is_buffered_channel) {
// Only 1 send must be successful
EXPECT_EQ(ct, 1);
} else {
// In unbuffered channel, no send should be successful
EXPECT_EQ(ct, 0);
}
// Join all threads
for (size_t i = 0; i < num_threads; i++) t[i].join();
}
// This tests that destroying a channel also unblocks
// any receivers waiting on the channel
void ChannelDestroyUnblockReceivers(Channel<int> *ch) {
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
// Launches threads that try to read and are blocked because of no writers
for (size_t i = 0; i < num_threads; i++) {
thread_ended[i] = false;
t[i] = std::thread(
[&](bool *p) {
int data;
// All reads should return false
EXPECT_EQ(ch->Receive(&data), false);
*p = true;
},
&thread_ended[i]);
}
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
// Verify that all threads are blocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], false);
}
// delete the channel
delete ch;
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], true);
}
for (size_t i = 0; i < num_threads; i++) t[i].join();
}
TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) {
size_t buffer_size = 1;
auto ch = MakeChannel<int>(buffer_size);
ChannelDestroyUnblockReceivers(ch);
}
TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) {
size_t buffer_size = 1;
auto ch = MakeChannel<int>(buffer_size);
ChannelDestroyUnblockSenders(ch);
}
// This tests that destroying an unbuffered channel also unblocks
// unblocks any receivers waiting for senders
TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) {
auto ch = MakeChannel<int>(0);
ChannelDestroyUnblockReceivers(ch);
}
TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
auto ch = MakeChannel<int>(0);
ChannelDestroyUnblockSenders(ch);
}
...@@ -42,8 +42,11 @@ class Buffered : public paddle::framework::Channel<T> { ...@@ -42,8 +42,11 @@ class Buffered : public paddle::framework::Channel<T> {
std::mutex mu_; std::mutex mu_;
std::condition_variable empty_cond_var_; std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_; std::condition_variable full_cond_var_;
std::condition_variable destructor_cond_var_;
std::deque<T> channel_; std::deque<T> channel_;
std::atomic<bool> closed_{false}; std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
Buffered(size_t cap) : cap_(cap), closed_(false) { Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0); PADDLE_ENFORCE_GT(cap, 0);
...@@ -58,6 +61,7 @@ bool Buffered<T>::Send(T* item) { ...@@ -58,6 +61,7 @@ bool Buffered<T>::Send(T* item) {
if (closed_) { if (closed_) {
return ret; return ret;
} }
send_ctr++;
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; }); [this]() { return channel_.size() < cap_ || closed_; });
...@@ -67,20 +71,30 @@ bool Buffered<T>::Send(T* item) { ...@@ -67,20 +71,30 @@ bool Buffered<T>::Send(T* item) {
empty_cond_var_.notify_one(); empty_cond_var_.notify_one();
ret = true; ret = true;
} }
send_ctr--;
destructor_cond_var_.notify_one();
return ret; return ret;
} }
template <typename T> template <typename T>
bool Buffered<T>::Receive(T* item) { bool Buffered<T>::Receive(T* item) {
bool ret = false;
// Once the channel has been closed and all data has been consumed,
// just return false. Don't even try acquiring the mutex.
if (closed_ && channel_.empty()) {
return false;
}
recv_ctr++;
std::unique_lock<std::mutex> lock(mu_); std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; }); empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
bool ret = false;
if (!channel_.empty()) { if (!channel_.empty()) {
*item = std::move(channel_.front()); *item = std::move(channel_.front());
channel_.pop_front(); channel_.pop_front();
full_cond_var_.notify_one(); full_cond_var_.notify_one();
ret = true; ret = true;
} }
recv_ctr--;
destructor_cond_var_.notify_one();
return ret; return ret;
} }
...@@ -100,6 +114,12 @@ Buffered<T>::~Buffered() { ...@@ -100,6 +114,12 @@ Buffered<T>::~Buffered() {
closed_ = true; closed_ = true;
channel_.clear(); channel_.clear();
NotifyAllParticipants(&lock); NotifyAllParticipants(&lock);
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
lock.lock();
destructor_cond_var_.wait(
lock, [this]() { return send_ctr == 0 && recv_ctr == 0; });
} }
template <typename T> template <typename T>
......
...@@ -45,9 +45,11 @@ class UnBuffered : public paddle::framework::Channel<T> { ...@@ -45,9 +45,11 @@ class UnBuffered : public paddle::framework::Channel<T> {
// A transaction occurs only when both are true // A transaction occurs only when both are true
std::atomic<bool> reader_found_{false}, writer_found_{false}; std::atomic<bool> reader_found_{false}, writer_found_{false};
std::condition_variable cv_channel_; std::condition_variable cv_channel_;
std::condition_variable_any cv_reader_, cv_writer_; std::condition_variable_any cv_reader_, cv_writer_, cv_destructor_;
T* item{nullptr}; T* item{nullptr};
std::atomic<bool> closed_{false}; std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
UnBuffered() : closed_(false) {} UnBuffered() : closed_(false) {}
...@@ -62,6 +64,7 @@ bool UnBuffered<T>::Send(T* data) { ...@@ -62,6 +64,7 @@ bool UnBuffered<T>::Send(T* data) {
if (closed_) { if (closed_) {
return ret; return ret;
} }
send_ctr++;
// Prevent other writers from entering // Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_); std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true; writer_found_ = true;
...@@ -81,6 +84,8 @@ bool UnBuffered<T>::Send(T* data) { ...@@ -81,6 +84,8 @@ bool UnBuffered<T>::Send(T* data) {
ret = true; ret = true;
} }
writer_found_ = false; writer_found_ = false;
send_ctr--;
cv_destructor_.notify_one();
return ret; return ret;
} }
...@@ -88,6 +93,12 @@ bool UnBuffered<T>::Send(T* data) { ...@@ -88,6 +93,12 @@ bool UnBuffered<T>::Send(T* data) {
// data that was sent by a writer is read from a reader. // data that was sent by a writer is read from a reader.
template <typename T> template <typename T>
bool UnBuffered<T>::Receive(T* data) { bool UnBuffered<T>::Receive(T* data) {
bool ret = false;
// If channel is closed, we don't even want any reader to enter.
// Unlike a buffered channel, an unbuffered channel does not allow
// readers to read after closing because there is no buffer to be consumed.
if (closed_) return ret;
recv_ctr++;
// Prevent other readers from entering // Prevent other readers from entering
std::unique_lock<std::recursive_mutex> read_lock{mu_read_}; std::unique_lock<std::recursive_mutex> read_lock{mu_read_};
reader_found_ = true; reader_found_ = true;
...@@ -96,7 +107,6 @@ bool UnBuffered<T>::Receive(T* data) { ...@@ -96,7 +107,6 @@ bool UnBuffered<T>::Receive(T* data) {
cv_reader_.wait(cv_lock, cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; }); [this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one(); cv_writer_.notify_one();
bool ret = false;
if (!closed_) { if (!closed_) {
std::unique_lock<std::mutex> lock_ch{mu_ch_}; std::unique_lock<std::mutex> lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data // Reader should wait for the writer to first write its data
...@@ -110,6 +120,8 @@ bool UnBuffered<T>::Receive(T* data) { ...@@ -110,6 +120,8 @@ bool UnBuffered<T>::Receive(T* data) {
cv_channel_.notify_one(); cv_channel_.notify_one();
} }
reader_found_ = false; reader_found_ = false;
recv_ctr--;
cv_destructor_.notify_one();
return ret; return ret;
} }
...@@ -135,6 +147,9 @@ UnBuffered<T>::~UnBuffered() { ...@@ -135,6 +147,9 @@ UnBuffered<T>::~UnBuffered() {
item = nullptr; item = nullptr;
closed_ = true; closed_ = true;
NotifyAllParticipants(&lock); NotifyAllParticipants(&lock);
lock.lock();
cv_destructor_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
} }
// This function notifies all the readers, writers and // This function notifies all the readers, writers and
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册