diff --git a/doc/design/cpp_data_feeding.md b/doc/design/cpp_data_feeding.md index 40205350f99722f0b71bfa6f390fe9d01d831966..a122af8cb9002cbf126c3b2a22ebdfc9a78f9e93 100644 --- a/doc/design/cpp_data_feeding.md +++ b/doc/design/cpp_data_feeding.md @@ -20,9 +20,8 @@ class ReaderBase { PADDLE_ENFORCE(!shapes_.empty()); } // Read the next batch of data. (A 'batch' can be only one instance) + // If the next batch doesn't exist, the 'out' will be an empty std::vector. virtual void ReadNext(std::vector* out) = 0; - // Show whether the next bacth exists. - virtual bool HasNext() const = 0; // Reinitialize the reader and read the file from the begin. virtual void ReInit() = 0; diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 27ab6e750c2e665fa5055a3ecfb2f315cb4000c0..1be3f4ef1f46bd8d72a285afa69b52d6f519ccf5 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -26,7 +26,6 @@ class ReaderBase { PADDLE_ENFORCE(!shapes_.empty()); } virtual void ReadNext(std::vector* out) = 0; - virtual bool HasNext() const = 0; virtual void ReInit() = 0; @@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase { PADDLE_ENFORCE_NOT_NULL(reader_); } - bool HasNext() const override { return reader_->HasNext(); } - void ReInit() override { reader_->ReInit(); } protected: @@ -69,7 +66,6 @@ class ReaderHolder { ReaderBase* Get() const { return reader_.get(); } void ReadNext(std::vector* out) { reader_->ReadNext(out); } - bool HasNext() const { return reader_->HasNext(); } void ReInit() { reader_->ReInit(); } DDim shape(size_t idx) const { return reader_->shape(idx); } diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 62beab82d4f2b0b795d5d32f50352172de6870cc..2a5605e0d378a184ae132e657b2872279784855d 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase { const platform::Place& dev_place) const override { framework::ReaderHolder* reader = scope.FindVar(Input("Reader"))->GetMutable(); - if (!reader->HasNext()) { + std::vector out_arg_names = Outputs("Out"); + std::vector ins; + reader->ReadNext(&ins); + if (ins.empty()) { reader->ReInit(); + reader->ReadNext(&ins); PADDLE_ENFORCE( - reader->HasNext(), + !ins.empty(), "Reader can not read the next data even it has been re-initialized."); } - std::vector out_arg_names = Outputs("Out"); - std::vector ins; - reader->ReadNext(&ins); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); for (size_t i = 0; i < ins.size(); ++i) { auto* out = diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index bac043a5529d877dba79c03f07b9d43c9b71d7aa..9559159e8298973d2c24d3ee820a8b2b42d80275 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector* out) { buffer_.clear(); buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { - if (reader_->HasNext()) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - } else { + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); + if (buffer.back().empty()) { + buffer_.pop_back(); break; } } diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index f77ab8ab196dae4cf9351cee9bc5566ec2c04e4b..73c39b5da4484b27f75aeba3c8171c5ffed2398f 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader { } } - bool HasNext() const override { return true; } - void ReInit() override { return; } private: diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 3e8b463efc99e4a962e5ae14ab133cf634548756..4dac3831109beeed660d32f08fb27c7adf62ac2b 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector* out) { buffer_.clear(); buffer_.reserve(buffer_size_); for (int i = 0; i < buffer_size_; ++i) { - if (reader_->HasNext()) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - } else { + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); + if (buffer_.back().empty()) { + buffer_.pop_back(); break; } }