diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 429313a339fe022d3bacc5da1fbc1ddb1f35241f..4d16b82e677f9b6a748db227b4ae380e6d09c17c 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -20,8 +20,11 @@ namespace reader { class BatchReader : public framework::DecoratedReader { public: - BatchReader(const std::shared_ptr& reader, int batch_size) - : DecoratedReader(reader), batch_size_(batch_size) { + BatchReader(const std::shared_ptr& reader, int batch_size, + bool discard_leftover) + : DecoratedReader(reader), + batch_size_(batch_size), + discard_leftover_(discard_leftover) { buffer_.reserve(batch_size_); Start(); } @@ -30,6 +33,7 @@ class BatchReader : public framework::DecoratedReader { private: int batch_size_; + bool discard_leftover_; std::vector> buffer_; }; @@ -47,8 +51,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new BatchReader(underlying_reader.Get(), Attr("batch_size"))); + out->Reset(new BatchReader(underlying_reader.Get(), Attr("batch_size"), + Attr("discard_leftover"))); } }; @@ -58,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { AddAttr("batch_size", "How many instances the batch reader yields each time.") .GreaterThan(0); + AddAttr("discard_leftover", + "If true, the leftover instances that are not enough for a " + "new batch will be discarded.") + .SetDefault(true); AddComment(R"DOC( CreateBatchReader Operator @@ -78,6 +86,9 @@ void BatchReader::ReadNextImpl(std::vector* out) { break; } } + if (discard_leftover_ && buffer_.size() < batch_size_) { + buffer_.clear(); + } // Concat instances out->clear(); if (buffer_.empty()) { diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 88c34187ea76305a0c0965d23acc30bc7554df24..efca6fe225bf4fd34d5e41e477e444541df58084 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -23,13 +23,13 @@ namespace reader { // 'Double buffer' means we shall maintain two batches of input data at the same // time. So the kCacheSize shoul be at least 2. -static constexpr size_t kCacheSize = 5; +static constexpr size_t kCacheSize = 3; // There will be two bacthes out of the channel during training: // 1. the one waiting to be sent to the channel // 2. the one just be received from the channel, which is also being used by // subsequent operators. // So the channel size should be kChacheSize - 2 -static constexpr size_t kChannelSize = 3; // kCacheSize - 2 +static constexpr size_t kChannelSize = 1; // kCacheSize - 2 class DoubleBufferReader : public framework::DecoratedReader { public: diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 7c8aa975a15ef9ff33554cc9452b037c45f7658e..82331cb272694b6f528732b1acd68b7ee614afe5 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -28,13 +28,11 @@ class MultiPassReader : public framework::DecoratedReader { void ReadNextImpl(std::vector* out) override { reader_->ReadNext(out); - if (out->empty()) { + if (out->empty() && pass_count_ < pass_num_ - 1) { + reader_->Shutdown(); + reader_->Start(); + reader_->ReadNext(out); ++pass_count_; - if (pass_count_ < pass_num_) { - reader_->Shutdown(); - reader_->Start(); - reader_->ReadNext(out); - } } }