diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index 2c4e622bd789c691c38a8810fe5e09e464a8cf1f..695169fcb9e93b5e69d3d4ae6f63f8e4c2d1605f 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -542,3 +542,341 @@ TEST(ChannelHolder, ChannelHolderUnBufferedSendReceiveTest) { ChannelHolderSendReceive(ch); delete ch; } + +TEST(ChannelHolder, ChannelUninitializedTest) { + ChannelHolder *ch = new ChannelHolder(); + EXPECT_EQ(ch->IsInitialized(), false); + int i = 10; + EXPECT_EQ(ch->Send(&i), false); + EXPECT_EQ(ch->Receive(&i), false); + bool is_exception = false; + try { + ch->Type(); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + delete ch; +} + +TEST(ChannelHolder, ChannelInitializedTest) { + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(2); + EXPECT_EQ(ch->IsInitialized(), true); + // Channel should remain intialized even after close + ch->close(); + EXPECT_EQ(ch->IsInitialized(), true); + delete ch; +} + +TEST(ChannelHolder, TypeMismatchSendTest) { + // Test with unbuffered channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(0); + bool is_exception = false; + bool boolean_data = true; + try { + ch->Send(&boolean_data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + delete ch; + + // Test with Buffered Channel + ch = new ChannelHolder(); + ch->Reset(10); + is_exception = false; + int int_data = 23; + try { + ch->Send(&int_data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + delete ch; +} + +TEST(ChannelHolder, TypeMismatchReceiveTest) { + // Test with unbuffered channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(0); + bool is_exception = false; + bool float_data; + try { + ch->Receive(&float_data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + delete ch; + + // Test with Buffered Channel + ch = new ChannelHolder(); + ch->Reset(10); + is_exception = false; + int int_data = 23; + try { + ch->Receive(&int_data); + } catch (paddle::platform::EnforceNotMet e) { + is_exception = true; + } + EXPECT_EQ(is_exception, true); + delete ch; +} + +void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) { + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked because of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data; + EXPECT_EQ(ch->Receive(&data), false); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec + + // Verify that all the threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + + // Explicitly close the channel + // This should unblock all receivers + ch->close(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); +} + +void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) { + using paddle::framework::details::Buffered; + using paddle::framework::details::UnBuffered; + + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + bool send_success[num_threads]; + + // Launches threads that try to write and are blocked because of no readers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + send_success[i] = false; + t[i] = std::thread( + [&](bool *ended, bool *success) { + int data = 10; + *success = ch->Send(&data); + *ended = true; + }, + &thread_ended[i], &send_success[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + + if (isBuffered) { + // If ch is Buffered, atleast 4 threads must be blocked. + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (!thread_ended[i]) ct++; + } + EXPECT_GE(ct, 4); + } else { + // If ch is UnBuffered, all the threads should be blocked. + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + } + // Explicitly close the thread + // This should unblock all senders + ch->close(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + if (isBuffered) { + // Verify that only 1 send was successful + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (send_success[i]) ct++; + } + // Only 1 send must be successful + EXPECT_EQ(ct, 1); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); +} + +// This tests that closing a channelholder unblocks +// any receivers waiting on the channel +TEST(ChannelHolder, ChannelHolderCloseUnblocksReceiversTest) { + // Check for buffered channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(1); + ChannelHolderCloseUnblocksReceiversTest(ch); + delete ch; + + // Check for unbuffered channel + ch = new ChannelHolder(); + ch->Reset(0); + ChannelHolderCloseUnblocksReceiversTest(ch); + delete ch; +} + +// This tests that closing a channelholder unblocks +// any senders waiting for channel to have write space +TEST(Channel, ChannelHolderCloseUnblocksSendersTest) { + // Check for buffered channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(1); + ChannelHolderCloseUnblocksSendersTest(ch, true); + delete ch; + + // Check for unbuffered channel + ch = new ChannelHolder(); + ch->Reset(0); + ChannelHolderCloseUnblocksSendersTest(ch, false); + delete ch; +} + +// This tests that destroying a channelholder unblocks +// any senders waiting for channel +void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) { + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + bool send_success[num_threads]; + + // Launches threads that try to write and are blocked because of no readers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + send_success[i] = false; + t[i] = std::thread( + [&](bool *ended, bool *success) { + int data = 10; + *success = ch->Send(&data); + *ended = true; + }, + &thread_ended[i], &send_success[i]); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec + if (isBuffered) { + // If channel is buffered, verify that atleast 4 threads are blocked + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (thread_ended[i] == false) ct++; + } + // Atleast 4 threads must be blocked + EXPECT_GE(ct, 4); + } else { + // Verify that all the threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + } + // Explicitly destroy the channel + delete ch; + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait + + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + // Count number of successfuld sends + int ct = 0; + for (size_t i = 0; i < num_threads; i++) { + if (send_success[i]) ct++; + } + + if (isBuffered) { + // Only 1 send must be successful + EXPECT_EQ(ct, 1); + } else { + // In unbuffered channel, no send should be successful + EXPECT_EQ(ct, 0); + } + + // Join all threads + for (size_t i = 0; i < num_threads; i++) t[i].join(); +} + +// This tests that destroying a channelholder also unblocks +// any receivers waiting on the channel +void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) { + size_t num_threads = 5; + std::thread t[num_threads]; + bool thread_ended[num_threads]; + + // Launches threads that try to read and are blocked because of no writers + for (size_t i = 0; i < num_threads; i++) { + thread_ended[i] = false; + t[i] = std::thread( + [&](bool *p) { + int data; + // All reads should return false + EXPECT_EQ(ch->Receive(&data), false); + *p = true; + }, + &thread_ended[i]); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait + + // Verify that all threads are blocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], false); + } + // delete the channel + delete ch; + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait + // Verify that all threads got unblocked + for (size_t i = 0; i < num_threads; i++) { + EXPECT_EQ(thread_ended[i], true); + } + + for (size_t i = 0; i < num_threads; i++) t[i].join(); +} + +TEST(ChannelHolder, ChannelHolderDestroyUnblocksReceiversTest) { + // Check for Buffered Channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(1); + ChannelHolderDestroyUnblockReceivers(ch); + // ch is already deleted already deleted in + // ChannelHolderDestroyUnblockReceivers + + // Check for Unbuffered channel + ch = new ChannelHolder(); + ch->Reset(0); + ChannelHolderDestroyUnblockReceivers(ch); +} + +TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) { + // Check for Buffered Channel + ChannelHolder *ch = new ChannelHolder(); + ch->Reset(1); + ChannelHolderDestroyUnblockSenders(ch, true); + // ch is already deleted already deleted in + // ChannelHolderDestroyUnblockReceivers + + // Check for Unbuffered channel + ch = new ChannelHolder(); + ch->Reset(0); + ChannelHolderDestroyUnblockSenders(ch, false); +}