diff --git a/paddle/fluid/framework/channel_impl.h b/paddle/fluid/framework/channel_impl.h index 378a0bab1cc7408266fa45a0b3dc19619dd4fb4c..c47d629289af2c3d1f7c30d711d338745bf5234c 100644 --- a/paddle/fluid/framework/channel_impl.h +++ b/paddle/fluid/framework/channel_impl.h @@ -87,6 +87,21 @@ class ChannelImpl : public paddle::framework::Channel { return value; } + std::shared_ptr get_first_message( + std::deque> &queue, ChannelAction action) { + while (!queue.empty()) { + // Check whether this message was added by Select + // If this was added by Select then execute the callback + // to check if you can execute this message. The callback + // can return false if some other case was executed in Select. + // In that case just discard this QueueMessage and process next. + std::shared_ptr m = queue.front(); + queue.pop_front(); + if (m->callback == nullptr || m->callback(action)) return m; + } + return nullptr; + } + size_t cap_; std::recursive_mutex mu_; bool closed_; @@ -131,36 +146,21 @@ void ChannelImpl::Send(T *item) { // If there is a receiver, directly pass the value we want // to send to the receiver, bypassing the channel buffer if any if (!recvq.empty()) { - std::shared_ptr m = recvq.front(); - recvq.pop_front(); - // Do the data transfer - // We will do this data transfer if either of the following - // cases are true - // 1. callback == nullptr // This means it was a regular channel send - // 2. callback returns true - bool do_send = true; - if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND); - if (do_send) + std::shared_ptr m = + get_first_message(recvq, ChannelAction::SEND); + + if (m != nullptr) { *(m->data) = std::move(*item); - else { - // We cannot do the data transfer because - // this QueueMessage was added by Select - // and some other case was executed. - // So call the Send function again. - // We do not care about notifying other - // because they would have been notified - // by the executed select case. + m->Notify(); + lock.unlock(); + send_return(); + return; + } else { lock.unlock(); Send(item); send_return(); return; } - - // Wake up the blocked process and unlock - m->Notify(); - lock.unlock(); - send_return(); - return; } // Unbuffered channel will always bypass this @@ -201,32 +201,34 @@ bool ChannelImpl::Receive(T *item) { } // If there is a sender, directly receive the value we want - // from the sender, bypassing the channel buffer if any + // from the sender. In case of a buffered channel, read from + // buffer and move front of send queue to the buffer if (!sendq.empty()) { - std::shared_ptr m = sendq.front(); - sendq.pop_front(); - // Do the data transfer - // We will do this data transfer if either of the following - // cases are true - // 1. callback == nullptr // This means it was a regular channel send - // 2. callback returns true - bool do_receive = true; - if (m->callback != nullptr) - do_receive = m->callback(ChannelAction::RECEIVE); - if (do_receive) - *item = std::move(*(m->data)); - else - // We cannot do the data transfer because - // this QueueMessage was added by Select - // and some other case was executed. - // So call the Receive function again. - // We do not care about notifying other - // because they would have been notified - // by the executed select case. - return recv_return(Receive(item)); - - // Wake up the blocked process and unlock - m->Notify(); + std::shared_ptr m = + get_first_message(sendq, ChannelAction::RECEIVE); + if (buf_.size() > 0) { + // Case 1 : Channel is Buffered + // Do Data transfer from front of buffer + // and add a QueueMessage to the buffer + *item = std::move(buf_.front()); + buf_.pop_front(); + // If first message from sendq is not null + // add it to the buffer and notify it + if (m != nullptr) { + // Copy to buffer + buf_.push_back(std::move(*(m->data))); + m->Notify(); + } // Ignore if there is no first message + } else { + // Case 2: Channel is Unbuffered + // Do data transfer from front of SendQ + // If front is nullptr, then recursively call itself + if (m != nullptr) { + *item = std::move(*(m->data)); + m->Notify(); + } else + return recv_return(Receive(item)); + } lock.unlock(); return recv_return(true); } diff --git a/paddle/fluid/framework/channel_test.cc b/paddle/fluid/framework/channel_test.cc index e2380bb54bd25c4f30f79cad30f95f7cb056eef0..1184bfdae1940286fb72d9091ae4f23ff7f84a54 100644 --- a/paddle/fluid/framework/channel_test.cc +++ b/paddle/fluid/framework/channel_test.cc @@ -36,23 +36,25 @@ TEST(Channel, ChannelCapacityTest) { delete ch; } -void RecevingOrderEqualToSendingOrder(Channel *ch) { +void RecevingOrderEqualToSendingOrder(Channel *ch, int num_items) { unsigned sum_send = 0; std::thread t([&]() { - for (int i = 0; i < 5; i++) { + for (int i = 0; i < num_items; i++) { ch->Send(&i); sum_send += i; } }); - for (int i = 0; i < 5; i++) { - int recv = 999; + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + for (int i = 0; i < num_items; i++) { + int recv = -1; EXPECT_EQ(ch->Receive(&recv), true); EXPECT_EQ(recv, i); } std::this_thread::sleep_for(std::chrono::milliseconds(200)); CloseChannel(ch); t.join(); - EXPECT_EQ(sum_send, 10U); + unsigned expected_sum = (num_items * (num_items - 1)) / 2; + EXPECT_EQ(sum_send, expected_sum); delete ch; } @@ -185,12 +187,28 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { TEST(Channel, RecevingOrderEqualToSendingOrderWithUnBufferedChannel) { auto ch = MakeChannel(0); - RecevingOrderEqualToSendingOrder(ch); + RecevingOrderEqualToSendingOrder(ch, 20); +} + +TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel1) { + // Test that Receive Order is same as Send Order when number of items + // sent is less than size of buffer + auto ch = MakeChannel(10); + RecevingOrderEqualToSendingOrder(ch, 5); +} + +TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel2) { + // Test that Receive Order is same as Send Order when number of items + // sent is equal to size of buffer + auto ch = MakeChannel(10); + RecevingOrderEqualToSendingOrder(ch, 10); } -TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel) { +TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel3) { + // Test that Receive Order is same as Send Order when number of items + // sent is greater than the size of buffer auto ch = MakeChannel(10); - RecevingOrderEqualToSendingOrder(ch); + RecevingOrderEqualToSendingOrder(ch, 20); } void ChannelCloseUnblocksReceiversTest(Channel *ch) {