未验证 提交 5f9da86b 编写于 作者: A Abhinav Arora 提交者: GitHub

Fix the order of reads and write from buffered channel (#9423)

* Fix Issue 9388

* Fix typos
上级 9bbd7534
......@@ -87,6 +87,21 @@ class ChannelImpl : public paddle::framework::Channel<T> {
return value;
}
std::shared_ptr<QueueMessage> get_first_message(
std::deque<std::shared_ptr<QueueMessage>> &queue, ChannelAction action) {
while (!queue.empty()) {
// Check whether this message was added by Select
// If this was added by Select then execute the callback
// to check if you can execute this message. The callback
// can return false if some other case was executed in Select.
// In that case just discard this QueueMessage and process next.
std::shared_ptr<QueueMessage> m = queue.front();
queue.pop_front();
if (m->callback == nullptr || m->callback(action)) return m;
}
return nullptr;
}
size_t cap_;
std::recursive_mutex mu_;
bool closed_;
......@@ -131,36 +146,21 @@ void ChannelImpl<T>::Send(T *item) {
// If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any
if (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
// Do the data transfer
// We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_send = true;
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
std::shared_ptr<QueueMessage> m =
get_first_message(recvq, ChannelAction::SEND);
if (m != nullptr) {
*(m->data) = std::move(*item);
else {
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Send function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
m->Notify();
lock.unlock();
send_return();
return;
} else {
lock.unlock();
Send(item);
send_return();
return;
}
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
send_return();
return;
}
// Unbuffered channel will always bypass this
......@@ -201,32 +201,34 @@ bool ChannelImpl<T>::Receive(T *item) {
}
// If there is a sender, directly receive the value we want
// from the sender, bypassing the channel buffer if any
// from the sender. In case of a buffered channel, read from
// buffer and move front of send queue to the buffer
if (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
// Do the data transfer
// We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_receive = true;
if (m->callback != nullptr)
do_receive = m->callback(ChannelAction::RECEIVE);
if (do_receive)
*item = std::move(*(m->data));
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Receive function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return recv_return(Receive(item));
// Wake up the blocked process and unlock
m->Notify();
std::shared_ptr<QueueMessage> m =
get_first_message(sendq, ChannelAction::RECEIVE);
if (buf_.size() > 0) {
// Case 1 : Channel is Buffered
// Do Data transfer from front of buffer
// and add a QueueMessage to the buffer
*item = std::move(buf_.front());
buf_.pop_front();
// If first message from sendq is not null
// add it to the buffer and notify it
if (m != nullptr) {
// Copy to buffer
buf_.push_back(std::move(*(m->data)));
m->Notify();
} // Ignore if there is no first message
} else {
// Case 2: Channel is Unbuffered
// Do data transfer from front of SendQ
// If front is nullptr, then recursively call itself
if (m != nullptr) {
*item = std::move(*(m->data));
m->Notify();
} else
return recv_return(Receive(item));
}
lock.unlock();
return recv_return(true);
}
......
......@@ -36,23 +36,25 @@ TEST(Channel, ChannelCapacityTest) {
delete ch;
}
void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
void RecevingOrderEqualToSendingOrder(Channel<int> *ch, int num_items) {
unsigned sum_send = 0;
std::thread t([&]() {
for (int i = 0; i < 5; i++) {
for (int i = 0; i < num_items; i++) {
ch->Send(&i);
sum_send += i;
}
});
for (int i = 0; i < 5; i++) {
int recv = 999;
std::this_thread::sleep_for(std::chrono::milliseconds(200));
for (int i = 0; i < num_items; i++) {
int recv = -1;
EXPECT_EQ(ch->Receive(&recv), true);
EXPECT_EQ(recv, i);
}
std::this_thread::sleep_for(std::chrono::milliseconds(200));
CloseChannel(ch);
t.join();
EXPECT_EQ(sum_send, 10U);
unsigned expected_sum = (num_items * (num_items - 1)) / 2;
EXPECT_EQ(sum_send, expected_sum);
delete ch;
}
......@@ -185,12 +187,28 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
TEST(Channel, RecevingOrderEqualToSendingOrderWithUnBufferedChannel) {
auto ch = MakeChannel<int>(0);
RecevingOrderEqualToSendingOrder(ch);
RecevingOrderEqualToSendingOrder(ch, 20);
}
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel1) {
// Test that Receive Order is same as Send Order when number of items
// sent is less than size of buffer
auto ch = MakeChannel<int>(10);
RecevingOrderEqualToSendingOrder(ch, 5);
}
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel2) {
// Test that Receive Order is same as Send Order when number of items
// sent is equal to size of buffer
auto ch = MakeChannel<int>(10);
RecevingOrderEqualToSendingOrder(ch, 10);
}
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel) {
TEST(Channel, RecevingOrderEqualToSendingOrderWithBufferedChannel3) {
// Test that Receive Order is same as Send Order when number of items
// sent is greater than the size of buffer
auto ch = MakeChannel<int>(10);
RecevingOrderEqualToSendingOrder(ch);
RecevingOrderEqualToSendingOrder(ch, 20);
}
void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册