未验证 提交 c680bc1d 编写于 作者: Y yuyang18

Rewrite DoubleBuffer

上级 76086df4
......@@ -67,7 +67,8 @@ void ReaderBase::Start() {
}
}
ReaderBase::~ReaderBase() { Shutdown(); }
ReaderBase::~ReaderBase() {}
DecoratedReader::~DecoratedReader() { reader_->Shutdown(); }
} // namespace framework
} // namespace paddle
......@@ -25,8 +25,6 @@
namespace paddle {
namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
virtual void ReadNext(std::vector<LoDTensor>* out);
......@@ -48,6 +46,8 @@ class ReaderBase {
virtual void StartImpl() {}
enum ReaderStatus { kRunning, kStopped };
ReaderStatus status_{kRunning};
mutable std::mutex mu_;
......@@ -74,6 +74,8 @@ class DecoratedReader : public ReaderBase,
reader_->InsertDecoratedReader(shared_from_this());
}
~DecoratedReader();
protected:
void ShutdownImpl() override { reader_->Shutdown(); }
......
......@@ -14,79 +14,79 @@
#include <thread> // NOLINT
#include "ThreadPool.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class BufferedReader : public framework::DecoratedReader {
using TensorVec = std::vector<framework::LoDTensor>;
using VecFuture = std::future<TensorVec>;
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(
const std::shared_ptr<ReaderBase>& reader,
platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize);
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
for (size_t i = 0; i < kCacheSize; ++i) {
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
}
BufferedReader(const std::shared_ptr<framework::ReaderBase>& reader,
const platform::Place& place, size_t buffer_size)
: framework::DecoratedReader(reader),
thread_pool_(1),
place_(place),
buffer_size_(buffer_size) {
AppendFutureToBatchSize();
}
~BufferedReader() override {
reader_->Shutdown();
buffer_.clear();
}
private:
void AppendFutureToBatchSize() {
while (buffer_.size() < buffer_size_) {
AppendFuture();
}
#endif
StartPrefetcher();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void AppendFuture() {
buffer_.emplace_back(thread_pool_.enqueue([this] {
TensorVec cpu_buffer;
reader_->ReadNext(&cpu_buffer);
if (platform::is_gpu_place(place_)) {
TensorVec gpu_buffer;
~DoubleBufferReader() { EndPrefetcher(); }
for (size_t i = 0; i < cpu_buffer.size(); ++i) {
gpu_buffer.emplace_back();
framework::TensorCopySync(cpu_buffer[i], place_, &gpu_buffer.back());
}
private:
cpu_buffer = gpu_buffer;
}
return cpu_buffer;
}));
}
protected:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
buffer_.clear();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
AppendFutureToBatchSize();
}
void EndPrefetcher() {
channel_->Close();
if (prefetcher_.joinable()) {
prefetcher_.join();
}
delete channel_;
channel_ = nullptr;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::cerr << "Read" << std::endl;
PADDLE_ENFORCE_EQ(buffer_.size(), buffer_size_);
*out = buffer_.front().get();
buffer_.pop_front();
AppendFuture();
}
void PrefetchThreadFunc();
std::thread prefetcher_;
reader::BlockingQueue<size_t>* channel_;
private:
ThreadPool thread_pool_;
platform::Place place_;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
const size_t buffer_size_;
std::list<VecFuture> buffer_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
......@@ -118,8 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
out->Reset(framework::MakeDecoratedReader<BufferedReader>(underlying_reader,
place, 2));
}
};
......@@ -146,51 +146,6 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) {
*out = gpu_tensor_cache_[cached_tensor_id];
} else {
// CPU place
*out = cpu_tensor_cache_[cached_tensor_id];
}
} else {
out->clear();
}
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0;
while (true) {
auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader have no next data.
break;
}
if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id];
gpu_batch.resize(cpu_batch.size());
for (size_t i = 0; i < cpu_batch.size(); ++i) {
// TODO(fengjiayi): Use asynchronous TensorCopy instead
framework::TensorCopySync(cpu_batch[i], place_, &gpu_batch[i]);
gpu_batch[i].set_lod(cpu_batch[i].lod());
}
}
if (!channel_->Send(cached_tensor_id)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;
}
channel_->Close();
VLOG(5) << "Prefetch thread terminates.";
}
} // namespace reader
} // namespace operators
} // namespace paddle
......
......@@ -190,3 +190,7 @@ class TestDataBalance(unittest.TestCase):
def test_all(self):
self.main()
self.main_lod()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册