From 4e517881f7e4d0ca8e3dac7234485fd6870418cc Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 9 Mar 2018 14:45:48 +0800 Subject: [PATCH] remove HasNext --- doc/design/cpp_data_feeding.md | 3 +-- paddle/fluid/framework/reader.h | 4 ---- paddle/fluid/operators/read_op.cc | 11 ++++++----- .../fluid/operators/reader/create_batch_reader_op.cc | 8 ++++---- .../reader/create_random_data_generator_op.cc | 2 -- .../operators/reader/create_shuffle_reader_op.cc | 8 ++++---- 6 files changed, 15 insertions(+), 21 deletions(-) diff --git a/doc/design/cpp_data_feeding.md b/doc/design/cpp_data_feeding.md index 40205350f99..a122af8cb90 100644 --- a/doc/design/cpp_data_feeding.md +++ b/doc/design/cpp_data_feeding.md @@ -20,9 +20,8 @@ class ReaderBase { PADDLE_ENFORCE(!shapes_.empty()); } // Read the next batch of data. (A 'batch' can be only one instance) + // If the next batch doesn't exist, the 'out' will be an empty std::vector. virtual void ReadNext(std::vector* out) = 0; - // Show whether the next bacth exists. - virtual bool HasNext() const = 0; // Reinitialize the reader and read the file from the begin. virtual void ReInit() = 0; diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 27ab6e750c2..1be3f4ef1f4 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -26,7 +26,6 @@ class ReaderBase { PADDLE_ENFORCE(!shapes_.empty()); } virtual void ReadNext(std::vector* out) = 0; - virtual bool HasNext() const = 0; virtual void ReInit() = 0; @@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase { PADDLE_ENFORCE_NOT_NULL(reader_); } - bool HasNext() const override { return reader_->HasNext(); } - void ReInit() override { reader_->ReInit(); } protected: @@ -69,7 +66,6 @@ class ReaderHolder { ReaderBase* Get() const { return reader_.get(); } void ReadNext(std::vector* out) { reader_->ReadNext(out); } - bool HasNext() const { return reader_->HasNext(); } void ReInit() { reader_->ReInit(); } DDim shape(size_t idx) const { return reader_->shape(idx); } diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 62beab82d4f..2a5605e0d37 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase { const platform::Place& dev_place) const override { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); - if (!reader->HasNext()) { + std::vector out_arg_names = Outputs("Out"); + std::vector ins; + reader->ReadNext(&ins); + if (ins.empty()) { reader->ReInit(); + reader->ReadNext(&ins); PADDLE_ENFORCE( - reader->HasNext(), + !ins.empty(), "Reader can not read the next data even it has been re-initialized."); } - std::vector out_arg_names = Outputs("Out"); - std::vector ins; - reader->ReadNext(&ins); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); for (size_t i = 0; i < ins.size(); ++i) { auto* out = diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index bac043a5529..9559159e829 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { - if (reader_->HasNext()) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - } else { + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); + if (buffer.back().empty()) { + buffer_.pop_back(); break; } } diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index f77ab8ab196..73c39b5da44 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader { } } - bool HasNext() const override { return true; } - void ReInit() override { return; } private: diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 3e8b463efc9..4dac3831109 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector* out) { buffer_.clear(); buffer_.reserve(buffer_size_); for (int i = 0; i < buffer_size_; ++i) { - if (reader_->HasNext()) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - } else { + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); + if (buffer_.back().empty()) { + buffer_.pop_back(); break; } } -- GitLab