提交 7eedced8 编写于 作者: Y Yu Yang

Polish RecordIO

上级 cfca8a3a
...@@ -33,6 +33,8 @@ class ReaderBase { ...@@ -33,6 +33,8 @@ class ReaderBase {
std::vector<DDim> shapes() const { return shapes_; } std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; } void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
// Pure-virtual query that every concrete reader must override.
// NOTE(review): presumably reports whether another item can still be read;
// confirm against the implementations.
virtual bool HasNext() const = 0;
virtual ~ReaderBase() {} virtual ~ReaderBase() {}
protected: protected:
...@@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase { ...@@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); } void ReInit() override { reader_->ReInit(); }
// Delegates the query to the wrapped reader and returns its answer unchanged.
bool HasNext() const override {
  const bool underlying_has_next = reader_->HasNext();
  return underlying_has_next;
}
protected: protected:
ReaderBase* reader_; ReaderBase* reader_;
}; };
...@@ -74,6 +78,8 @@ class ReaderHolder { ...@@ -74,6 +78,8 @@ class ReaderHolder {
reader_->set_shapes(shapes); reader_->set_shapes(shapes);
} }
// Forwards to the held reader's HasNext() and returns its result as-is.
bool HasNext() const {
  return reader_->HasNext();
}
private: private:
std::unique_ptr<ReaderBase> reader_; std::unique_ptr<ReaderBase> reader_;
}; };
......
...@@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
~DoubleBufferReader() { buffer_->Close(); } ~DoubleBufferReader() { buffer_->Close(); }
bool HasNext() const override;
private: private:
void PrefetchThreadFunc(); void PrefetchThreadFunc();
...@@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { ...@@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
} }
} }
// Stub: the body unconditionally invokes PADDLE_THROW("Not Implemented") and
// never reaches a return statement of its own.
bool DoubleBufferReader::HasNext() const {
  PADDLE_THROW("Not Implemented");
}
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader { ...@@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader {
void ReInit() override { return; } void ReInit() override { return; }
// Always answers true; this override never reports exhaustion.
bool HasNext() const override {
  return true;
}
private: private:
float min_; float min_;
float max_; float max_;
......
...@@ -146,6 +146,7 @@ bool Chunk::Parse(std::istream& sin) { ...@@ -146,6 +146,7 @@ bool Chunk::Parse(std::istream& sin) {
std::string buf; std::string buf;
buf.resize(rec_len); buf.resize(rec_len);
stream.read(&buf[0], rec_len); stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf); Add(buf);
} }
return true; return true;
......
...@@ -32,9 +32,9 @@ void Scanner::Reset() { ...@@ -32,9 +32,9 @@ void Scanner::Reset() {
ParseNextChunk(); ParseNextChunk();
} }
const std::string &Scanner::Next() { std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration"); PADDLE_ENFORCE(!eof_, "StopIteration");
auto &rec = cur_chunk_.Record(offset_++); auto rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) { if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk(); ParseNextChunk();
} }
......
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
void Reset(); void Reset();
const std::string& Next(); std::string Next();
bool HasNext() const; bool HasNext() const;
......
...@@ -41,4 +41,29 @@ TEST(WriterScanner, Normal) { ...@@ -41,4 +41,29 @@ TEST(WriterScanner, Normal) {
ASSERT_EQ("CDE", scanner.Next()); ASSERT_EQ("CDE", scanner.Next());
ASSERT_FALSE(scanner.HasNext()); ASSERT_FALSE(scanner.HasNext());
} }
}
// Writes four records through a Writer whose third argument (annotated
// "max chunk num") is 2, then checks that a Scanner over the same stream
// yields the records in write order and afterwards reports HasNext() == false.
TEST(WriterScanner, TinyChunk) {
  // Own the stream from the very first line so it is released even if a
  // writer call throws before ownership is handed over to the scanner.
  std::unique_ptr<std::stringstream> stream(new std::stringstream());
  {
    // The writer only borrows the stream; it goes out of scope (after an
    // explicit Flush) before the scanner takes ownership below.
    paddle::recordio::Writer writer(stream.get(),
                                    paddle::recordio::Compressor::kNoCompress,
                                    2 /*max chunk num*/);
    writer.Write("ABC");
    writer.Write("BCD");
    writer.Write("CDE");
    writer.Write("DEFG");
    writer.Flush();
  }
  {
    // Rewind the read position to the beginning before scanning.
    stream->seekg(0, std::ios::beg);
    // Transfer ownership as a std::istream, which is what Scanner receives.
    std::unique_ptr<std::istream> stream_ptr(std::move(stream));
    paddle::recordio::Scanner scanner(std::move(stream_ptr));
    ASSERT_TRUE(scanner.HasNext());
    ASSERT_EQ(scanner.Next(), "ABC");
    ASSERT_EQ(scanner.Next(), "BCD");
    ASSERT_EQ(scanner.Next(), "CDE");
    ASSERT_EQ(scanner.Next(), "DEFG");
    ASSERT_FALSE(scanner.HasNext());
  }
}
\ No newline at end of file
...@@ -55,15 +55,10 @@ class TestRecordIO(unittest.TestCase): ...@@ -55,15 +55,10 @@ class TestRecordIO(unittest.TestCase):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
avg_loss_np = [] avg_loss_np = []
for i in xrange(2): # 2 pass # train a pass
batch_id = 0 while not data_file.eof():
while not data_file.eof(): tmp, = exe.run(fetch_list=[avg_loss])
try: avg_loss_np.append(tmp)
batch_id += 1 data_file.reset()
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
except:
print batch_id
break
data_file.reset()
self.assertLess(avg_loss_np[-1], avg_loss_np[0]) self.assertLess(avg_loss_np[-1], avg_loss_np[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册