提交 4e517881 编写于 作者: F fengjiayi

remove HasNext

上级 b825c792
...@@ -20,9 +20,8 @@ class ReaderBase { ...@@ -20,9 +20,8 @@ class ReaderBase {
PADDLE_ENFORCE(!shapes_.empty()); PADDLE_ENFORCE(!shapes_.empty());
} }
// Read the next batch of data. (A 'batch' can be only one instance) // Read the next batch of data. (A 'batch' can be only one instance)
// If the next batch doesn't exist, the 'out' will be an empty std::vector.
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
// Show whether the next bacth exists.
virtual bool HasNext() const = 0;
// Reinitialize the reader and read the file from the begin. // Reinitialize the reader and read the file from the begin.
virtual void ReInit() = 0; virtual void ReInit() = 0;
......
...@@ -26,7 +26,6 @@ class ReaderBase { ...@@ -26,7 +26,6 @@ class ReaderBase {
PADDLE_ENFORCE(!shapes_.empty()); PADDLE_ENFORCE(!shapes_.empty());
} }
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual bool HasNext() const = 0;
virtual void ReInit() = 0; virtual void ReInit() = 0;
...@@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase { ...@@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
bool HasNext() const override { return reader_->HasNext(); }
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
protected: protected:
...@@ -69,7 +66,6 @@ class ReaderHolder { ...@@ -69,7 +66,6 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); } ReaderBase* Get() const { return reader_.get(); }
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); } void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); }
void ReInit() { reader_->ReInit(); } void ReInit() { reader_->ReInit(); }
DDim shape(size_t idx) const { return reader_->shape(idx); } DDim shape(size_t idx) const { return reader_->shape(idx); }
......
...@@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase { ...@@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase {
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
framework::ReaderHolder* reader = framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>(); scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
if (!reader->HasNext()) { std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit(); reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE( PADDLE_ENFORCE(
reader->HasNext(), !ins.empty(),
"Reader can not read the next data even it has been re-initialized."); "Reader can not read the next data even it has been re-initialized.");
} }
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
auto* out = auto* out =
......
...@@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
buffer_.clear(); buffer_.clear();
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) { for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) { buffer_.push_back(std::vector<framework::LoDTensor>());
buffer_.push_back(std::vector<framework::LoDTensor>()); reader_->ReadNext(&buffer_.back());
reader_->ReadNext(&buffer_.back()); if (buffer.back().empty()) {
} else { buffer_.pop_back();
break; break;
} }
} }
......
...@@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader {
} }
} }
bool HasNext() const override { return true; }
void ReInit() override { return; } void ReInit() override { return; }
private: private:
......
...@@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) { for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) { buffer_.push_back(std::vector<framework::LoDTensor>());
buffer_.push_back(std::vector<framework::LoDTensor>()); reader_->ReadNext(&buffer_.back());
reader_->ReadNext(&buffer_.back()); if (buffer_.back().empty()) {
} else { buffer_.pop_back();
break; break;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册