Unverified commit 9c7fa6ff, authored by fengjiayi, committed by GitHub

Merge pull request #10206 from JiayiFeng/blocking_queue_for_reader

Blocking queue for reader
paddle/fluid/operators/reader/CMakeLists.txt
@@ -23,5 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
 reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
 reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
 reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
+cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
+
 # Export local libraries to parent
 set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
paddle/fluid/operators/reader/blocking_queue.h (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <condition_variable>  // NOLINT
#include <deque>
#include <mutex>  // NOLINT

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace reader {
template <typename T>
class BlockingQueue {
  // BlockingQueue is for buffered reading and is supposed to be used only
  // within the reader package. Ideally we would use framework::Channel, but
  // it currently has a deadlock bug. BlockingQueue is a workaround: a
  // simplified version of framework::Channel that does not support GPU and
  // implements only a bounded, buffered blocking queue.
public:
explicit BlockingQueue(size_t capacity)
: capacity_(capacity), closed_(false) {
PADDLE_ENFORCE_GT(
capacity_, 0,
"The capacity of a reader::BlockingQueue must be greater than 0.");
}
bool Send(const T& elem) {
std::unique_lock<std::mutex> lock(mutex_);
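    // Block until there is room in the queue or the queue has been closed.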
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
if (closed_) {
      VLOG(5)
          << "WARNING: Sending an element to a closed reader::BlockingQueue.";
return false;
}
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.push_back(elem);
receive_cv_.notify_one();
return true;
}
bool Send(T&& elem) {
std::unique_lock<std::mutex> lock(mutex_);
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
if (closed_) {
      VLOG(5)
          << "WARNING: Sending an element to a closed reader::BlockingQueue.";
return false;
}
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.emplace_back(std::move(elem));
receive_cv_.notify_one();
return true;
}
bool Receive(T* elem) {
std::unique_lock<std::mutex> lock(mutex_);
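    // Block until an element is available or the queue has been closed.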
receive_cv_.wait(lock, [&] { return !queue_.empty() || closed_; });
if (!queue_.empty()) {
PADDLE_ENFORCE_NOT_NULL(elem);
*elem = queue_.front();
queue_.pop_front();
send_cv_.notify_one();
return true;
} else {
PADDLE_ENFORCE(closed_);
return false;
}
}
void Close() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
send_cv_.notify_all();
receive_cv_.notify_all();
}
bool IsClosed() {
std::lock_guard<std::mutex> lock(mutex_);
return closed_;
}
size_t Cap() {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
private:
size_t capacity_;
bool closed_;
std::deque<T> queue_;
std::mutex mutex_;
std::condition_variable receive_cv_;
std::condition_variable send_cv_;
};
} // namespace reader
} // namespace operators
} // namespace paddle
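
A minimal usage sketch of the queue's protocol, assuming only the header above (the names here are illustrative, not part of the commit): Send() blocks while the queue is full and returns false once the queue has been closed; Receive() blocks while the queue is empty and returns false only after the queue is closed and drained.

// Sketch: one producer thread and one consumer exercising the protocol.
#include <iostream>
#include <thread>  // NOLINT
#include "paddle/fluid/operators/reader/blocking_queue.h"

int main() {
  paddle::operators::reader::BlockingQueue<int> q(4);
  std::thread producer([&q] {
    for (int i = 0; i < 10; ++i) {
      if (!q.Send(i)) break;  // false: the queue was closed early
    }
    q.Close();  // lets the consumer drain the buffer, then unblocks it
  });
  int v;
  while (q.Receive(&v)) {  // false only after Close() and an empty buffer
    std::cout << v << "\n";
  }
  producer.join();
  return 0;
}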
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
@@ -14,7 +14,7 @@
 #include <thread>  // NOLINT

-#include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/operators/reader/blocking_queue.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"

 namespace paddle {
@@ -23,13 +23,13 @@ namespace reader {
 // 'Double buffer' means we shall maintain two batches of input data at the same
 // time. So the kCacheSize should be at least 2.
-static constexpr size_t kCacheSize = 2;
+static constexpr size_t kCacheSize = 3;
 // There will be two batches out of the channel during training:
 // 1. the one waiting to be sent to the channel
 // 2. the one just received from the channel, which is also being used by
 //    subsequent operators.
 // So the channel size should be kCacheSize - 2.
-static constexpr size_t kChannelSize = 0;  // kCacheSize - 2
+static constexpr size_t kChannelSize = 1;  // kCacheSize - 2

 class DoubleBufferReader : public framework::DecoratedReader {
  public:
@@ -55,10 +55,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
   ~DoubleBufferReader() { EndPrefetcher(); }

  private:
-  bool HasNext() const;
-
   void StartPrefetcher() {
-    channel_ = framework::MakeChannel<size_t>(kChannelSize);
+    channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
     prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
   }
@@ -74,7 +72,7 @@ class DoubleBufferReader : public framework::DecoratedReader {
   void PrefetchThreadFunc();

   std::thread prefetcher_;
-  framework::Channel<size_t>* channel_;
+  reader::BlockingQueue<size_t>* channel_;
   platform::Place place_;
   std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
   std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
@@ -139,17 +137,16 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 };

 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
-  out->clear();
-  if (HasNext()) {
-    size_t cached_tensor_id;
-    channel_->Receive(&cached_tensor_id);
+  size_t cached_tensor_id;
+  if (channel_->Receive(&cached_tensor_id)) {
     if (platform::is_gpu_place(place_)) {
       *out = gpu_tensor_cache_[cached_tensor_id];
-      ctxs_[cached_tensor_id]->Wait();
     } else {
       // CPU place
       *out = cpu_tensor_cache_[cached_tensor_id];
     }
+  } else {
+    out->clear();
   }
 }
@@ -159,12 +156,6 @@ void DoubleBufferReader::ReInit() {
   StartPrefetcher();
 }

-bool DoubleBufferReader::HasNext() const {
-  while (!channel_->IsClosed() && !channel_->CanReceive()) {
-  }
-  return channel_->CanReceive();
-}
-
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
   size_t cached_tensor_id = 0;
@@ -185,10 +176,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
         gpu_batch[i].set_lod(cpu_batch[i].lod());
       }
     }
-    try {
-      size_t tmp = cached_tensor_id;
-      channel_->Send(&tmp);
-    } catch (paddle::platform::EnforceNotMet e) {
+    if (!channel_->Send(cached_tensor_id)) {
       VLOG(5) << "WARNING: The double buffer channel has been closed. The "
                  "prefetch thread will terminate.";
       break;
...
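
The change above replaces exception-based shutdown with Send()'s boolean return value. A simplified stand-in for PrefetchThreadFunc (hypothetical names, not the actual op code) showing the resulting loop shape:

// Sketch: the prefetcher exits as soon as Send() reports a closed queue.
void PrefetchLoop(paddle::operators::reader::BlockingQueue<size_t>* channel,
                  size_t cache_size) {
  size_t cached_tensor_id = 0;
  while (true) {
    // ... fill the cache slot `cached_tensor_id` from the underlying reader ...
    if (!channel->Send(cached_tensor_id)) {
      break;  // consumer closed the queue; terminate the prefetch thread
    }
    cached_tensor_id = (cached_tensor_id + 1) % cache_size;
  }
}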
paddle/fluid/operators/reader/open_files_op.cc
@@ -14,7 +14,7 @@
 #include <thread>  // NOLINT

-#include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/operators/reader/blocking_queue.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"

 namespace paddle {
@@ -37,7 +37,6 @@ class MultiFileReader : public framework::ReaderBase {
   ~MultiFileReader() { EndScheduler(); }

  private:
-  bool HasNext();
   void StartNewScheduler();
   void EndScheduler();
   void ScheduleThreadFunc();
@@ -48,15 +47,14 @@ class MultiFileReader : public framework::ReaderBase {
   std::thread scheduler_;
   std::vector<std::thread> prefetchers_;
   size_t buffer_size_;
-  framework::Channel<size_t>* waiting_file_idx_;
-  framework::Channel<size_t>* available_thread_idx_;
-  framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
+  reader::BlockingQueue<size_t>* waiting_file_idx_;
+  reader::BlockingQueue<size_t>* available_thread_idx_;
+  reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
 };

 void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
-  out->clear();
-  if (HasNext()) {
-    buffer_->Receive(out);
+  if (!buffer_->Receive(out)) {
+    out->clear();
   }
 }
@@ -65,25 +63,19 @@ void MultiFileReader::ReInit() {
   StartNewScheduler();
 }

-bool MultiFileReader::HasNext() {
-  while (!buffer_->IsClosed() && !buffer_->CanReceive()) {
-  }
-  return buffer_->CanReceive();
-}
-
 void MultiFileReader::StartNewScheduler() {
   size_t thread_num = prefetchers_.size();
-  waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
-  available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
-  buffer_ =
-      framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
+  waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
+  available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
+  buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
+      buffer_size_);

   for (size_t i = 0; i < file_names_.size(); ++i) {
-    waiting_file_idx_->Send(&i);
+    waiting_file_idx_->Send(i);
   }
   waiting_file_idx_->Close();
   for (size_t i = 0; i < thread_num; ++i) {
-    available_thread_idx_->Send(&i);
+    available_thread_idx_->Send(i);
   }

   scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
@@ -149,7 +141,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
       break;
     }
     try {
-      buffer_->Send(&ins);
+      buffer_->Send(std::move(ins));
     } catch (paddle::platform::EnforceNotMet e) {
       VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
                  "thread of file '"
@@ -158,9 +150,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
     }
   }

-  try {
-    available_thread_idx_->Send(&thread_idx);
-  } catch (paddle::platform::EnforceNotMet e) {
+  if (!available_thread_idx_->Send(thread_idx)) {
     VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
                "Fail to send thread_idx.";
   }
...
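
Beyond batch buffering, the patch also uses a BlockingQueue<size_t> as a pool of worker slots (available_thread_idx_ above). A self-contained sketch of that pattern, with hypothetical names and a trivial "job" standing in for file prefetching:

#include <thread>  // NOLINT
#include <vector>
#include "paddle/fluid/operators/reader/blocking_queue.h"

using paddle::operators::reader::BlockingQueue;

int main() {
  const size_t kThreads = 4;
  BlockingQueue<size_t> free_slots(kThreads);
  for (size_t i = 0; i < kThreads; ++i) {
    free_slots.Send(i);  // initially, every worker slot is available
  }
  std::vector<std::thread> workers(kThreads);
  for (int job = 0; job < 8; ++job) {
    size_t slot;
    if (!free_slots.Receive(&slot)) break;  // pool closed, no slot available
    if (workers[slot].joinable()) workers[slot].join();
    workers[slot] = std::thread([&free_slots, slot] {
      // ... do one unit of work on worker `slot` ...
      free_slots.Send(slot);  // recycle the slot; returns false once closed
    });
  }
  free_slots.Close();
  for (auto& w : workers) {
    if (w.joinable()) w.join();
  }
  return 0;
}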
paddle/fluid/operators/reader/reader_blocking_queue_test.cc (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <chrono> // NOLINT
#include <set>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
using paddle::operators::reader::BlockingQueue;
TEST(BlockingQueue, CapacityTest) {
size_t cap = 10;
BlockingQueue<int> q(cap);
EXPECT_EQ(q.Cap(), cap);
}
void FirstInFirstOut(size_t queue_cap, size_t elem_num, size_t send_time_gap,
size_t receive_time_gap) {
BlockingQueue<size_t> q(queue_cap);
std::thread sender([&]() {
for (size_t i = 0; i < elem_num; ++i) {
std::this_thread::sleep_for(std::chrono::milliseconds(send_time_gap));
EXPECT_TRUE(q.Send(i));
}
q.Close();
});
size_t count = 0;
while (true) {
std::this_thread::sleep_for(std::chrono::milliseconds(receive_time_gap));
size_t elem;
if (!q.Receive(&elem)) {
break;
}
EXPECT_EQ(elem, count++);
}
sender.join();
EXPECT_EQ(count, elem_num);
EXPECT_TRUE(q.IsClosed());
}
TEST(BlockingQueue, FirstInFirstOutTest) {
FirstInFirstOut(2, 5, 2, 50);
FirstInFirstOut(2, 5, 50, 2);
FirstInFirstOut(10, 3, 50, 2);
FirstInFirstOut(10, 3, 2, 50);
}
TEST(BlockingQueue, SenderBlockingTest) {
const size_t queue_cap = 2;
BlockingQueue<size_t> q(queue_cap);
size_t send_count = 0;
std::thread sender([&]() {
for (size_t i = 0; i < 5; ++i) {
if (!q.Send(i)) {
break;
}
++send_count;
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(200));
q.Close();
sender.join();
EXPECT_EQ(send_count, queue_cap);
std::vector<size_t> res;
while (true) {
size_t elem;
if (!q.Receive(&elem)) {
break;
}
res.push_back(elem);
}
EXPECT_EQ(res.size(), queue_cap);
for (size_t i = 0; i < res.size(); ++i) {
EXPECT_EQ(res[i], i);
}
}
TEST(BlockingQueue, ReceiverBlockingTest) {
const size_t queue_cap = 5;
BlockingQueue<size_t> q(queue_cap);
std::vector<size_t> receive_res;
std::thread receiver([&]() {
size_t elem;
while (true) {
if (!q.Receive(&elem)) {
break;
}
receive_res.push_back(elem);
}
});
std::vector<size_t> to_send{2, 1, 7};
for (auto e : to_send) {
q.Send(e);
}
q.Close();
receiver.join();
EXPECT_EQ(receive_res.size(), to_send.size());
for (size_t i = 0; i < to_send.size(); ++i) {
EXPECT_EQ(receive_res[i], to_send[i]);
}
}
void CheckIsUnorderedSame(const std::vector<std::vector<size_t>>& v1,
const std::vector<std::vector<size_t>>& v2) {
std::set<size_t> s1;
std::set<size_t> s2;
for (auto vec : v1) {
for (size_t elem : vec) {
s1.insert(elem);
}
}
for (auto vec : v2) {
for (size_t elem : vec) {
s2.insert(elem);
}
}
EXPECT_EQ(s1.size(), s2.size());
auto it1 = s1.begin();
auto it2 = s2.begin();
while (it1 != s1.end()) {
EXPECT_EQ(*it1, *it2);
++it1;
++it2;
}
}
void MultiSenderMultiReceiver(const size_t queue_cap,
const std::vector<std::vector<size_t>>& to_send,
size_t receiver_num, size_t send_time_gap,
size_t receive_time_gap) {
BlockingQueue<size_t> q(queue_cap);
size_t sender_num = to_send.size();
std::vector<std::thread> senders;
for (size_t s_idx = 0; s_idx < sender_num; ++s_idx) {
senders.emplace_back(std::thread([&, s_idx] {
for (size_t elem : to_send[s_idx]) {
std::this_thread::sleep_for(std::chrono::milliseconds(send_time_gap));
EXPECT_TRUE(q.Send(elem));
}
}));
}
std::vector<std::thread> receivers;
std::mutex mu;
std::vector<std::vector<size_t>> res;
for (size_t r_idx = 0; r_idx < receiver_num; ++r_idx) {
receivers.emplace_back(std::thread([&] {
std::vector<size_t> receiver_res;
while (true) {
std::this_thread::sleep_for(
std::chrono::milliseconds(receive_time_gap));
size_t elem;
if (!q.Receive(&elem)) {
break;
}
receiver_res.push_back(elem);
}
std::lock_guard<std::mutex> lock(mu);
res.push_back(receiver_res);
}));
}
for (auto& t : senders) {
t.join();
}
q.Close();
for (auto& t : receivers) {
t.join();
}
CheckIsUnorderedSame(to_send, res);
}
TEST(BlockingQueue, MultiSenderMultiReaderTest) {
std::vector<std::vector<size_t>> to_send_1{{2, 3, 4}, {9}, {0, 7, 15, 6}};
MultiSenderMultiReceiver(2, to_send_1, 2, 0, 0);
MultiSenderMultiReceiver(10, to_send_1, 2, 0, 0);
MultiSenderMultiReceiver(2, to_send_1, 20, 0, 0);
MultiSenderMultiReceiver(2, to_send_1, 2, 50, 0);
MultiSenderMultiReceiver(2, to_send_1, 2, 0, 50);
std::vector<std::vector<size_t>> to_send_2{
{2, 3, 4}, {}, {0, 7, 15, 6, 9, 32}};
MultiSenderMultiReceiver(2, to_send_2, 3, 0, 0);
MultiSenderMultiReceiver(20, to_send_2, 3, 0, 0);
MultiSenderMultiReceiver(2, to_send_2, 30, 0, 0);
MultiSenderMultiReceiver(2, to_send_2, 3, 50, 0);
MultiSenderMultiReceiver(2, to_send_2, 3, 0, 50);
}
struct MyClass {
MyClass() : val_(0) {}
explicit MyClass(int val) : val_(val) {}
MyClass(const MyClass& b) { val_ = b.val_; }
MyClass(MyClass&& b) { val_ = b.val_; }
  MyClass& operator=(const MyClass& b) {
    val_ = b.val_;
    return *this;
  }
int val_;
};
TEST(BlockingQueue, MyClassTest) {
  BlockingQueue<MyClass> q(2);
  MyClass a(200);
  EXPECT_TRUE(q.Send(std::move(a)));
  MyClass b;
  EXPECT_TRUE(q.Receive(&b));
  // MyClass's move constructor copies val_ without clearing the source, so
  // the moved-from `a` still holds 200 and this comparison is well-defined.
  EXPECT_EQ(a.val_, b.val_);
}