diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index 6416c04f36e49a72c17af8fb6a91420f788f7fcb..df9e15e22b890347a03d6816e8549c99b010bb38 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -60,6 +60,16 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) { delete ch; } +TEST(Channel, SendOnClosedChannelPanics) { + const size_t buffer_size = 10; + auto ch = MakeChannel(buffer_size); + size_t i = 5; + EXPECT_EQ(ch->Send(&i), true); // should not block or panic + CloseChannel(ch); + EXPECT_EQ(ch->Send(&i), false); // should panic + delete ch; +} + TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { const size_t buffer_size = 10; auto ch = MakeChannel(buffer_size); @@ -88,7 +98,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) { // Note: we cannot check EXPECT_EQ(out, 0), because C++ doesn't // define zero values like Go does. } - delete ch; } diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h index b9761eab9b52780d383ecf649a58c5e9152a9765..00b63da4da7844b41168c03f55e2faa84ff44154 100644 --- a/paddle/framework/details/buffered_channel.h +++ b/paddle/framework/details/buffered_channel.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include @@ -42,7 +43,7 @@ class Buffered : public paddle::framework::Channel { std::condition_variable empty_cond_var_; std::condition_variable full_cond_var_; std::deque channel_; - bool closed_; + std::atomic closed_{false}; Buffered(size_t cap) : cap_(cap), closed_(false) { PADDLE_ENFORCE_GT(cap, 0); @@ -53,10 +54,13 @@ class Buffered : public paddle::framework::Channel { template bool Buffered::Send(T* item) { + bool ret = false; + if (closed_) { + return ret; + } std::unique_lock lock(mu_); full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_ || closed_; }); - bool ret = false; if (!closed_) { channel_.push_back(std::move(*item)); lock.unlock(); @@ -82,6 +86,9 @@ bool Buffered::Receive(T* item) { template void Buffered::Close() { + if (closed_) { + return; + } std::unique_lock lock(mu_); closed_ = true; NotifyAllParticipants(&lock); diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h index f86a894bb4a42e45edf6964e30620b68183faaa8..815cebad2d8c08aa31bb566bc6c51250870383d8 100644 --- a/paddle/framework/details/unbuffered_channel.h +++ b/paddle/framework/details/unbuffered_channel.h @@ -58,6 +58,10 @@ class UnBuffered : public paddle::framework::Channel { // be sent from a writer to a reader. template bool UnBuffered::Send(T* data) { + bool ret = false; + if (closed_) { + return ret; + } // Prevent other writers from entering std::unique_lock writer_lock(mu_write_); writer_found_ = true; @@ -66,7 +70,6 @@ bool UnBuffered::Send(T* data) { cv_writer_.wait(cv_lock, [this]() { return reader_found_ == true || closed_; }); cv_reader_.notify_one(); - bool ret = false; if (!closed_) { std::unique_lock channel_lock(mu_ch_); item = data; @@ -114,6 +117,9 @@ bool UnBuffered::Receive(T* data) { // that take place once the channel is closed. template void UnBuffered::Close() { + if (closed_) { + return; + } std::unique_lock lock(mu_ch_); item = nullptr; closed_ = true;