From 6fc6cc2f4c18e5443e57035fe4d5b7368a36cd42 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sun, 8 Jul 2018 18:28:10 +0800 Subject: [PATCH] Some updates on readers 1. Shrink DoubleBufferReader's buffer size to 3. 2. Add BatchReader an option to discard leftover instances. 3. Fix a MultiPassReader bug on pass count. --- .../reader/create_batch_reader_op.cc | 19 +++++++++++++++---- .../reader/create_double_buffer_reader_op.cc | 4 ++-- .../reader/create_multi_pass_reader_op.cc | 10 ++++------ 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 429313a339..4d16b82e67 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -20,8 +20,11 @@ namespace reader { class BatchReader : public framework::DecoratedReader { public: - BatchReader(const std::shared_ptr& reader, int batch_size) - : DecoratedReader(reader), batch_size_(batch_size) { + BatchReader(const std::shared_ptr& reader, int batch_size, + bool discard_leftover) + : DecoratedReader(reader), + batch_size_(batch_size), + discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); Start(); } @@ -30,6 +33,7 @@ class BatchReader : public framework::DecoratedReader { private: int batch_size_; + bool discard_leftover_; std::vector> buffer_; }; @@ -47,8 +51,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new BatchReader(underlying_reader.Get(), Attr("batch_size"))); + out->Reset(new BatchReader(underlying_reader.Get(), Attr("batch_size"), + Attr("discard_leftover"))); } }; @@ -58,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { AddAttr("batch_size", "How many instances the batch reader yields each time.") .GreaterThan(0); + AddAttr("discard_leftover", + "If true, the leftover instances that are not enough for a " + "new batch will be discarded.") + .SetDefault(true); AddComment(R"DOC( CreateBatchReader Operator @@ -78,6 +86,9 @@ void BatchReader::ReadNextImpl(std::vector* out) { break; } } + if (discard_leftover_ && buffer_.size() < batch_size_) { + buffer_.clear(); + } // Concat instances out->clear(); if (buffer_.empty()) { diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 88c34187ea..efca6fe225 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -23,13 +23,13 @@ namespace reader { // 'Double buffer' means we shall maintain two batches of input data at the same // time. So the kCacheSize shoul be at least 2. -static constexpr size_t kCacheSize = 5; +static constexpr size_t kCacheSize = 3; // There will be two bacthes out of the channel during training: // 1. the one waiting to be sent to the channel // 2. the one just be received from the channel, which is also being used by // subsequent operators. // So the channel size should be kChacheSize - 2 -static constexpr size_t kChannelSize = 3; // kCacheSize - 2 +static constexpr size_t kChannelSize = 1; // kCacheSize - 2 class DoubleBufferReader : public framework::DecoratedReader { public: diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 7c8aa975a1..82331cb272 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -28,13 +28,11 @@ class MultiPassReader : public framework::DecoratedReader { void ReadNextImpl(std::vector* out) override { reader_->ReadNext(out); - if (out->empty()) { + if (out->empty() && pass_count_ < pass_num_ - 1) { + reader_->Shutdown(); + reader_->Start(); + reader_->ReadNext(out); ++pass_count_; - if (pass_count_ < pass_num_) { - reader_->Shutdown(); - reader_->Start(); - reader_->ReadNext(out); - } } } -- GitLab