提交 7eedced8 编写于 作者: Y Yu Yang

Polish RecordIO

上级 cfca8a3a
......@@ -33,6 +33,8 @@ class ReaderBase {
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
virtual bool HasNext() const = 0;
virtual ~ReaderBase() {}
protected:
......@@ -53,6 +55,8 @@ class DecoratedReader : public ReaderBase {
void ReInit() override { reader_->ReInit(); }
bool HasNext() const override { return reader_->HasNext(); }
protected:
ReaderBase* reader_;
};
......@@ -74,6 +78,8 @@ class ReaderHolder {
reader_->set_shapes(shapes);
}
bool HasNext() const { return reader_->HasNext(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
......
......@@ -37,6 +37,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
~DoubleBufferReader() { buffer_->Close(); }
bool HasNext() const override;
private:
void PrefetchThreadFunc();
......@@ -106,6 +108,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
}
}
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
} // namespace reader
} // namespace operators
} // namespace paddle
......
......@@ -52,6 +52,8 @@ class RandomDataGenerator : public framework::FileReader {
void ReInit() override { return; }
bool HasNext() const override { return true; }
private:
float min_;
float max_;
......
......@@ -146,6 +146,7 @@ bool Chunk::Parse(std::istream& sin) {
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf);
}
return true;
......
......@@ -32,9 +32,9 @@ void Scanner::Reset() {
ParseNextChunk();
}
const std::string &Scanner::Next() {
std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration");
auto &rec = cur_chunk_.Record(offset_++);
auto rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
}
......
......@@ -28,7 +28,7 @@ public:
void Reset();
const std::string& Next();
std::string Next();
bool HasNext() const;
......
......@@ -42,3 +42,28 @@ TEST(WriterScanner, Normal) {
ASSERT_FALSE(scanner.HasNext());
}
}
TEST(WriterScanner, TinyChunk) {
std::stringstream* stream = new std::stringstream();
{
paddle::recordio::Writer writer(
stream, paddle::recordio::Compressor::kNoCompress, 2 /*max chunk num*/);
writer.Write("ABC");
writer.Write("BCD");
writer.Write("CDE");
writer.Write("DEFG");
writer.Flush();
}
{
stream->seekg(0, std::ios::beg);
std::unique_ptr<std::istream> stream_ptr(stream);
paddle::recordio::Scanner scanner(std::move(stream_ptr));
ASSERT_TRUE(scanner.HasNext());
ASSERT_EQ(scanner.Next(), "ABC");
ASSERT_EQ(scanner.Next(), "BCD");
ASSERT_EQ(scanner.Next(), "CDE");
ASSERT_EQ(scanner.Next(), "DEFG");
ASSERT_FALSE(scanner.HasNext());
}
}
\ No newline at end of file
......@@ -55,15 +55,10 @@ class TestRecordIO(unittest.TestCase):
exe.run(fluid.default_startup_program())
avg_loss_np = []
for i in xrange(2): # 2 pass
batch_id = 0
# train a pass
while not data_file.eof():
try:
batch_id += 1
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
except:
print batch_id
break
data_file.reset()
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册