提交 4e517881 编写于 作者: F fengjiayi

remove HasNext

上级 b825c792
......@@ -20,9 +20,8 @@ class ReaderBase {
PADDLE_ENFORCE(!shapes_.empty());
}
// Read the next batch of data. (A 'batch' can be only one instance)
// If the next batch doesn't exist, the 'out' will be an empty std::vector.
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
// Show whether the next bacth exists.
virtual bool HasNext() const = 0;
// Reinitialize the reader and read the file from the begin.
virtual void ReInit() = 0;
......
......@@ -26,7 +26,6 @@ class ReaderBase {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual bool HasNext() const = 0;
virtual void ReInit() = 0;
......@@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
bool HasNext() const override { return reader_->HasNext(); }
void ReInit() override { reader_->ReInit(); }
protected:
......@@ -69,7 +66,6 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); }
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); }
void ReInit() { reader_->ReInit(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
......
......@@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
if (!reader->HasNext()) {
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE(
reader->HasNext(),
!ins.empty(),
"Reader can not read the next data even it has been re-initialized.");
}
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) {
auto* out =
......
......@@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
if (buffer.back().empty()) {
buffer_.pop_back();
break;
}
}
......
......@@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader {
}
}
bool HasNext() const override { return true; }
void ReInit() override { return; }
private:
......
......@@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
if (buffer_.back().empty()) {
buffer_.pop_back();
break;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册