diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc index df9e15e22b890347a03d6816e8549c99b010bb38..3b8150b4276f61b7559ec3e3437a8a1eca02616b 100644 --- a/paddle/framework/channel_test.cc +++ b/paddle/framework/channel_test.cc @@ -115,7 +115,7 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { sum += i; } }); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec EXPECT_EQ(sum, 45U); CloseChannel(ch); @@ -144,38 +144,34 @@ TEST(Channel, SimpleUnbufferedChannelTest) { delete ch; } -// This tests that closing a buffered channel also unblocks -// any receivers waiting on the channel -TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) { - auto ch = MakeChannel(1); +void ChannelCloseUnblocksReceiversTest(Channel *ch) { size_t num_threads = 5; std::thread t[num_threads]; bool thread_ended[num_threads]; - // Launches threads that try to read and are blocked because of no writers + // Launches threads that try to read and are blocked becausew of no writers for (size_t i = 0; i < num_threads; i++) { thread_ended[i] = false; t[i] = std::thread( [&](bool *p) { int data; - // All reads should return false EXPECT_EQ(ch->Receive(&data), false); *p = true; }, &thread_ended[i]); } - std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec - // Verify that all threads are blocked + // Verify that all the threads are blocked for (size_t i = 0; i < num_threads; i++) { EXPECT_EQ(thread_ended[i], false); } - // Explicitly close the channel + // Explicitly close the thread // This should unblock all receivers CloseChannel(ch); - std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec // Verify that all threads got unblocked for (size_t i = 0; i < num_threads; i++) { @@ -183,13 +179,12 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) { } for (size_t i = 0; i < num_threads; i++) t[i].join(); - delete ch; } -// This tests that closing a buffered channel also unblocks -// any senders waiting for channel to have write space -TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { - auto ch = MakeChannel(1); +void ChannelCloseUnblocksSendersTest(Channel *ch) { + using paddle::framework::details::Buffered; + using paddle::framework::details::UnBuffered; + size_t num_threads = 5; std::thread t[num_threads]; bool thread_ended[num_threads]; @@ -209,34 +204,56 @@ TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { } std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait - // Verify that atleast 4 threads are blocked - int ct = 0; - for (size_t i = 0; i < num_threads; i++) { - if (thread_ended[i] == false) ct++; + if (dynamic_cast *>(ch)) { + // If ch is Buffered, atleast 4 threads must be blocked. + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (!thread_ended[i]) ct++; + } + EXPECT_GE(ct, 4); + } else { + // If ch is UnBuffered, all the threads should be blocked. + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } } - // Atleast 4 threads must be blocked - EXPECT_GE(ct, 4); - // Explicitly close the thread // This should unblock all senders CloseChannel(ch); - std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait // Verify that all threads got unblocked for (size_t i = 0; i < num_threads; i++) { EXPECT_EQ(thread_ended[i], true); } - // Verify that only 1 send was successful - ct = 0; - for (size_t i = 0; i < num_threads; i++) { - if (send_success[i]) ct++; + if (dynamic_cast *>(ch)) { + // Verify that only 1 send was successful + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (send_success[i]) ct++; + } + // Only 1 send must be successful + EXPECT_EQ(ct, 1); } - // Only 1 send must be successful - EXPECT_EQ(ct, 1); for (size_t i = 0; i < num_threads; i++) t[i].join(); +} + +// This tests that closing a buffered channel also unblocks +// any receivers waiting on the channel +TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) { + auto ch = MakeChannel(1); + ChannelCloseUnblocksReceiversTest(ch); + delete ch; +} + +// This tests that closing a buffered channel also unblocks +// any senders waiting for channel to have write space +TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { + auto ch = MakeChannel(1); + ChannelCloseUnblocksSendersTest(ch); delete ch; } @@ -244,40 +261,7 @@ TEST(Channel, BufferedChannelCloseUnblocksSendersTest) { // unblocks any receivers waiting for senders TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { auto ch = MakeChannel(0); - size_t num_threads = 5; - std::thread t[num_threads]; - bool thread_ended[num_threads]; - - // Launches threads that try to read and are blocked becausew of no writers - for (size_t i = 0; i < num_threads; i++) { - thread_ended[i] = false; - t[i] = std::thread( - [&](bool *p) { - int data; - EXPECT_EQ(ch->Receive(&data), false); - *p = true; - }, - &thread_ended[i]); - } - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec - - // Verify that all the threads are blocked - for (size_t i = 0; i < num_threads; i++) { - EXPECT_EQ(thread_ended[i], false); - } - - // Explicitly close the thread - // This should unblock all receivers - CloseChannel(ch); - - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec - - // Verify that all threads got unblocked - for (size_t i = 0; i < num_threads; i++) { - EXPECT_EQ(thread_ended[i], true); - } - - for (size_t i = 0; i < num_threads; i++) t[i].join(); + ChannelCloseUnblocksReceiversTest(ch); delete ch; } @@ -285,40 +269,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) { // unblocks any senders waiting for senders TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) { auto ch = MakeChannel(0); - size_t num_threads = 5; - std::thread t[num_threads]; - bool thread_ended[num_threads]; - - // Launches threads that try to read and are blocked becausew of no writers - for (size_t i = 0; i < num_threads; i++) { - thread_ended[i] = false; - t[i] = std::thread( - [&](bool *p) { - int data = 10; - EXPECT_EQ(ch->Send(&data), false); - *p = true; - }, - &thread_ended[i]); - } - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec - - // Verify that all the threads are blocked - for (size_t i = 0; i < num_threads; i++) { - EXPECT_EQ(thread_ended[i], false); - } - - // Explicitly close the thread - // This should unblock all receivers - CloseChannel(ch); - - std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec - - // Verify that all threads got unblocked - for (size_t i = 0; i < num_threads; i++) { - EXPECT_EQ(thread_ended[i], true); - } - - for (size_t i = 0; i < num_threads; i++) t[i].join(); + ChannelCloseUnblocksReceiversTest(ch); delete ch; } diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h index 00b63da4da7844b41168c03f55e2faa84ff44154..44bf84eb3098ea93121f4f38094f496cd67cd9f0 100644 --- a/paddle/framework/details/buffered_channel.h +++ b/paddle/framework/details/buffered_channel.h @@ -25,6 +25,13 @@ namespace paddle { namespace framework { namespace details { +// Four of the properties of Buffered Channel: +// - A send to a full channel blocks temporarily until a receive from the +// channel or the channel is closed +// - A receive from an empty channel blocks temporarily until a send to the +// channel or the channel is closed +// - A send to a closed channel returns false immediately +// - A receive from a closed channel returns false immediately template class Buffered : public paddle::framework::Channel { friend Channel* paddle::framework::MakeChannel(size_t);