未验证 提交 5f9da86b 编写于 作者: A Abhinav Arora 提交者: GitHub

Fix the order of reads and write from buffered channel (#9423)

* Fix Issue 9388

* Fix typos
上级 9bbd7534
...@@ -87,6 +87,21 @@ class ChannelImpl : public paddle::framework::Channel<T> { ...@@ -87,6 +87,21 @@ class ChannelImpl : public paddle::framework::Channel<T> {
return value; return value;
} }
std::shared_ptr<QueueMessage> get_first_message(
std::deque<std::shared_ptr<QueueMessage>> &queue, ChannelAction action) {
while (!queue.empty()) {
// Check whether this message was added by Select
// If this was added by Select then execute the callback
// to check if you can execute this message. The callback
// can return false if some other case was executed in Select.
// In that case just discard this QueueMessage and process next.
std::shared_ptr<QueueMessage> m = queue.front();
queue.pop_front();
if (m->callback == nullptr || m->callback(action)) return m;
}
return nullptr;
}
size_t cap_; size_t cap_;
std::recursive_mutex mu_; std::recursive_mutex mu_;
bool closed_; bool closed_;
...@@ -131,36 +146,21 @@ void ChannelImpl<T>::Send(T *item) { ...@@ -131,36 +146,21 @@ void ChannelImpl<T>::Send(T *item) {
// If there is a receiver, directly pass the value we want // If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any // to send to the receiver, bypassing the channel buffer if any
if (!recvq.empty()) { if (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front(); std::shared_ptr<QueueMessage> m =
recvq.pop_front(); get_first_message(recvq, ChannelAction::SEND);
// Do the data transfer
// We will do this data transfer if either of the following if (m != nullptr) {
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_send = true;
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
*(m->data) = std::move(*item); *(m->data) = std::move(*item);
else { m->Notify();
// We cannot do the data transfer because lock.unlock();
// this QueueMessage was added by Select send_return();
// and some other case was executed. return;
// So call the Send function again. } else {
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
lock.unlock(); lock.unlock();
Send(item); Send(item);
send_return(); send_return();
return; return;
} }
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
send_return();
return;
} }
// Unbuffered channel will always bypass this // Unbuffered channel will always bypass this
...@@ -201,32 +201,34 @@ bool ChannelImpl<T>::Receive(T *item) { ...@@ -201,32 +201,34 @@ bool ChannelImpl<T>::Receive(T *item) {
} }
// If there is a sender, directly receive the value we want // If there is a sender, directly receive the value we want
// from the sender, bypassing the channel buffer if any // from the sender. In case of a buffered channel, read from
// buffer and move front of send queue to the buffer
if (!sendq.empty()) { if (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front(); std::shared_ptr<QueueMessage> m =
sendq.pop_front(); get_first_message(sendq, ChannelAction::RECEIVE);
// Do the data transfer if (buf_.size() > 0) {
// We will do this data transfer if either of the following // Case 1 : Channel is Buffered
// cases are true // Do Data transfer from front of buffer
// 1. callback == nullptr // This means it was a regular channel send // and add a QueueMessage to the buffer
// 2. callback returns true *item = std::move(buf_.front());
bool do_receive = true; buf_.pop_front();
if (m->callback != nullptr) // If first message from sendq is not null
do_receive = m->callback(ChannelAction::RECEIVE); // add it to the buffer and notify it
if (do_receive) if (m != nullptr) {
*item = std::move(*(m->data)); // Copy to buffer
else buf_.push_back(std::move(*(m->data)));
// We cannot do the data transfer because m->Notify();
// this QueueMessage was added by Select } // Ignore if there is no first message
// and some other case was executed. } else {
// So call the Receive function again. // Case 2: Channel is Unbuffered
// We do not care about notifying other // Do data transfer from front of SendQ
// because they would have been notified // If front is nullptr, then recursively call itself
// by the executed select case. if (m != nullptr) {
return recv_return(Receive(item)); *item = std::move(*(m->data));
m->Notify();
// Wake up the blocked process and unlock } else
m->Notify(); return recv_return(Receive(item));
}
lock.unlock(); lock.unlock();
return recv_return(true); return recv_return(true);
} }
......
...@@ -36,23 +36,25 @@ TEST(Channel, ChannelCapacityTest) { ...@@ -36,23 +36,25 @@ TEST(Channel, ChannelCapacityTest) {
delete ch; delete ch;
} }
void RecevingOrderEqualToSendingOrder(Channel<int> *ch) { void RecevingOrderEqualToSendingOrder(Channel<int> *ch, int num_items) {
unsigned sum_send = 0; unsigned sum_send = 0;
std::thread t([&]() { std::thread t([&]() {
for (int i = 0; i < 5; i++) { for (int i = 0; i < num_items; i++) {
ch->Send(&i); ch->Send(&i);
sum_send += i; sum_send += i;
} }
}); });
for (int i = 0; i < 5; i++) { std::this_thread::sleep_for(std::chrono::milliseconds(200));
int recv = 999; for (int i = 0; i < num_items; i++) {
int recv = -1;
EXPECT_EQ(ch->Receive(&recv), true); EXPECT_EQ(ch->Receive(&recv), true);
EXPECT_EQ(recv, i); EXPECT_EQ(recv, i);
} }
std::this_thread::sleep_for(std::chrono::milliseconds(200)); std::this_thread::sleep_for(std::chrono::milliseconds(200));
CloseChannel(ch); CloseChannel(ch);
t.join(); t.join();
EXPECT_EQ(sum_send, 10U); unsigned expected_sum = (num_items * (num_items - 1)) / 2;
EXPECT_EQ(sum_send, expected_sum);
delete ch; delete ch;
} }
...@@ -185,12 +187,28 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) { ...@@ -185,12 +187,28 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
TEST(Channel, RecevingOrderEqualToSendingOrderWithUnBufferedChannel) { TEST(Channel, RecevingOrderEqualToSendingOrderWithUnBufferedChannel) {
auto ch = MakeChannel<int>(0); auto ch = MakeChannel<int>(0);
RecevingOrderEqualToSendingOrder(ch); RecevingOrderEqualToSendingOrder(ch, 20);
}
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel1) {
// Test that Receive Order is same as Send Order when number of items
// sent is less than size of buffer
auto ch = MakeChannel<int>(10);
RecevingOrderEqualToSendingOrder(ch, 5);
}
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel2) {
// Test that Receive Order is same as Send Order when number of items
// sent is equal to size of buffer
auto ch = MakeChannel<int>(10);
RecevingOrderEqualToSendingOrder(ch, 10);
} }
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel) { TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel3) {
// Test that Receive Order is same as Send Order when number of items
// sent is greater than the size of buffer
auto ch = MakeChannel<int>(10); auto ch = MakeChannel<int>(10);
RecevingOrderEqualToSendingOrder(ch); RecevingOrderEqualToSendingOrder(ch, 20);
} }
void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) { void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册