diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index 020f806380626d2f1efac683741ee84f1b573aeb..31ac72eda98859327f9857c18287398d0f459c7b 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -29,16 +29,16 @@ TEST(Channel, MakeAndClose) { { // MakeChannel should return a buffered channel is buffer_size > 0. auto ch = MakeChannel(10); - EXPECT_NE(dynamic_cast*>(ch), nullptr); - EXPECT_EQ(dynamic_cast*>(ch), nullptr); + EXPECT_NE(dynamic_cast *>(ch), nullptr); + EXPECT_EQ(dynamic_cast *>(ch), nullptr); CloseChannel(ch); delete ch; } { // MakeChannel should return an un-buffered channel is buffer_size = 0. auto ch = MakeChannel(0); - EXPECT_EQ(dynamic_cast*>(ch), nullptr); - EXPECT_NE(dynamic_cast*>(ch), nullptr); + EXPECT_EQ(dynamic_cast *>(ch), nullptr); + EXPECT_NE(dynamic_cast *>(ch), nullptr); CloseChannel(ch); delete ch; } @@ -100,6 +100,88 @@ TEST(Channel, SimpleUnbufferedChannelTest) { delete ch; } +// This tests that closing an unbuffered channel also unblocks +// unblocks any receivers waiting for senders +TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { + auto ch = MakeChannel(0); + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked becausew of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data; + ch->Receive(&data); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all the threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + + // Explicitly close the thread + // This should unblock all receivers + CloseChannel(ch); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); + delete ch; +} + +// This tests that closing an unbuffered channel also unblocks +// unblocks any senders waiting for senders +TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) { + auto ch = MakeChannel(0); + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked becausew of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data = 10; + ch->Send(&data); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all the threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + + // Explicitly close the thread + // This should unblock all receivers + CloseChannel(ch); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); + delete ch; +} + TEST(Channel, UnbufferedLessReceiveMoreSendTest) { auto ch = MakeChannel(0); unsigned sum_send = 0;