From 3fcd16ede3dd71b269fed8ae213d18491b65f186 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 6 Mar 2018 19:08:23 +0800 Subject: [PATCH] init double buffer --- paddle/fluid/framework/reader.cc | 41 ++++++++++++++++++++++++++++++++ paddle/fluid/framework/reader.h | 25 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index dc1caa72a4c..9cdce11d375 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -112,5 +112,46 @@ void BatchReader::ReadNext(std::vector* out) { out->push_back(out_tensor); } } + +void DoubleBufferReader::ReadNext(std::vector* out) { + std::unique_lock lck(mtx_); + while (write_pos_ == read_pos_) { + buffer_not_empty_.wait(lck); + } + + out->clear(); + out->resize(buffer_[read_pos_].size()); + // TODO(fengjiayi): This copy shall be reduced. + for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) { + TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &out[i]); + out[i].set_lod(buffer_[read_pos_][i].lod()); + } + + ++read_pos_; + if (read_pos_ >= kDoubleBufferSize) { + read_pos_ = 0; + } + buffer_not_full_.notify_all(); +} + +bool DoubleBufferReader::HasNext() { + return reader_->HasNext() || !buffer_.empty(); +} + +void DoubleBufferReader::ProducerThreadFunc() { + while (reader_->HasNext()) { + std::unique_lock lck(mtx); + while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) { + buffer_not_full_.wait(lck); + } + reader_->ReadNext(&buffer_[write_pos_]); + ++write_pos_; + if (write_pos_ >= kDoubleBufferSize) { + write_pos_ = 0; + } + buffer_not_empty_.notify_all(); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 4a5eba5fb73..917412ce9b2 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -16,10 +16,13 @@ #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/threadpool.h" namespace paddle { namespace framework { +static constexpr size_t kDoubleBufferSize = 3; + class ReaderBase { public: explicit ReaderBase(const std::vector& shapes) : shapes_(shapes) { @@ -135,6 +138,28 @@ class BatchReader : public DecoratedReader { std::vector> buffer_; }; +class DoubleBufferReader : public DecoratedReader { + public: + DoubleBufferReader(ReaderBase* reader) + : DecoratedReader(reader), buffer_(kDoubleBufferSize) { + framework::Async(std::bind(&DoubleBufferReader::ProducerThreadFunc, this)); + } + + void ReadNext(std::vector* out) override; + bool HasNext() const override; + + private: + void ProducerThreadFunc(); + + std::vector> buffer_; + size_t write_pos_; + size_t read_pos_; + + std::mutex mtx_; + std::condition_variable buffer_not_full_; + std::condition_variable buffer_not_empty_; +}; + // The ReaderHolder is used as readers' unified wrapper, // making it easier to access different type readers in Variables. class ReaderHolder { -- GitLab