From 35e1e0d521c61c60893631ece2f0cd83635aec86 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 9 Mar 2018 16:23:16 +0800 Subject: [PATCH] uses channel to replace the traditional buffer --- .../reader/create_double_buffer_reader_op.cc | 77 ++++++++----------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 435689fcdb..b6a0609a1e 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -12,42 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include #include +#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { namespace operators { namespace reader { -static constexpr size_t kDoubleBufferSize = 3; +static constexpr size_t kDoubleBufferSize = 2; class DoubleBufferReader : public framework::DecoratedReader { public: explicit DoubleBufferReader(ReaderBase* reader) : DecoratedReader(reader), - buffer_(kDoubleBufferSize), - write_pos_(0), - read_pos_(0) { - std::thread prefetch( - std::bind(&DoubleBufferReader::PrefetchThreadFunc, this)); + buffer_(framework::MakeChannel>( + kDoubleBufferSize)) { + std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); prefetch.detach(); } void ReadNext(std::vector* out) override; - bool HasNext() const override; + void ReInit() override; + + ~DoubleBufferReader() { buffer_->Close(); } private: void PrefetchThreadFunc(); - std::vector> buffer_; - size_t write_pos_; - size_t read_pos_; - - std::mutex mtx_; - std::condition_variable buffer_not_full_; - std::condition_variable buffer_not_empty_; + framework::Channel>* buffer_; }; class CreateDoubleBufferReaderOp : public framework::OperatorBase { @@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { }; void DoubleBufferReader::ReadNext(std::vector* out) { - std::unique_lock lck(mtx_); - while (write_pos_ == read_pos_) { - buffer_not_empty_.wait(lck); - } - out->clear(); - out->reserve(buffer_[read_pos_].size()); - // TODO(fengjiayi): This copy shall be reduced. - for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) { - framework::LoDTensor dst; - TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst); - dst.set_lod(buffer_[read_pos_][i].lod()); - out->push_back(dst); - } - - ++read_pos_; - if (read_pos_ >= kDoubleBufferSize) { - read_pos_ = 0; - } - buffer_not_full_.notify_all(); + buffer_->Receive(out); } -bool DoubleBufferReader::HasNext() const { - return reader_->HasNext() || !buffer_.empty(); +void DoubleBufferReader::ReInit() { + reader_->ReInit(); + buffer_->Close(); + // The existing prefetch thread will terminate for the buffer_ is closed. + buffer_ = framework::MakeChannel>( + kDoubleBufferSize); + std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); + prefetch.detach(); } void DoubleBufferReader::PrefetchThreadFunc() { - while (reader_->HasNext()) { - std::unique_lock lck(mtx_); - while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) { - buffer_not_full_.wait(lck); + VLOG(5) << "A new prefetch thread starts."; + while (true) { + std::vector batch; + reader_->ReadNext(&batch); + if (batch.empty()) { + // EOF + buffer_->Close(); + VLOG(5) << "Reached the end of the file. The prefetch thread terminates."; + break; } - reader_->ReadNext(&buffer_[write_pos_]); - ++write_pos_; - if (write_pos_ >= kDoubleBufferSize) { - write_pos_ = 0; + if (!buffer_->Send(&batch)) { + VLOG(5) << "WARNING: The double buffer channel has been closed. The " + "prefetch thread terminates."; + break; } - buffer_not_empty_.notify_all(); } } -- GitLab