From d8cc21da53e1113aaee3b43ea77d136bbbd204bb Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 12:58:14 +0800 Subject: [PATCH] refine inheritance relationship --- paddle/framework/reader.cc | 2 +- paddle/framework/reader.h | 66 +++++++++++++++++++++----------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index e11662166c..a05bef42ff 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -17,7 +17,7 @@ namespace paddle { namespace framework { -DDim Reader::shape(size_t idx) const { +DDim FileReader::shape(size_t idx) const { PADDLE_ENFORCE_LT( idx, shapes_.size(), "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 58675863e5..3954a1bea8 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -20,32 +20,48 @@ namespace paddle { namespace framework { -class Reader { +class ReaderBase { public: - Reader() {} - explicit Reader(const std::vector& shapes) : shapes_(shapes) {} - virtual std::vector ReadNext() = 0; virtual bool HasNext() const = 0; - virtual DDim shape(size_t idx) const; - virtual std::vector shapes() const { return shapes_; } + virtual DDim shape(size_t idx) const = 0; + virtual std::vector shapes() const = 0; - virtual ~Reader() {} + virtual ~ReaderBase() {} +}; - private: - // set private to prevent directly access in decorators - // a decorator should access its underlying reader_'s shape, not its own. +class FileReader : public ReaderBase { + public: + explicit FileReader(const std::vector& shapes) : shapes_(shapes) {} + + DDim shape(size_t idx) const override; + std::vector shapes() const override { return shapes_; } + + protected: std::vector shapes_; }; +class ReaderDecorator : public ReaderBase { + public: + explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {} + + bool HasNext() const override { return reader_->HasNext(); } + + DDim shape(size_t idx) const override { return reader_->shape(idx); } + std::vector shapes() const override { return reader_->shapes(); } + + protected: + ReaderBase* reader_; +}; + // file readers template -class RandomReader : public Reader { +class RandomReader : public FileReader { public: RandomReader(const std::vector& shapes, float min, float max) - : Reader(shapes), min_(min), max_(max) { + : FileReader(shapes), min_(min), max_(max) { PADDLE_ENFORCE_LE(min, max, "'min' should be less than or equal to 'max'.(%f vs %f)", min, max); @@ -58,8 +74,8 @@ class RandomReader : public Reader { std::uniform_real_distribution dist(min_, max_); std::vector res; - res.reserve(shapes().size()); - for (const DDim& shape : shapes()) { + res.reserve(shapes_.size()); + for (const DDim& shape : shapes_) { PADDLE_ENFORCE_GE( shape.size(), 2, "The rank of input data should be 2 at least.(Now it's %d)", @@ -85,37 +101,27 @@ class RandomReader : public Reader { // decorators -class ShuffleReader : public Reader { +class ShuffleReader : public ReaderDecorator { public: - ShuffleReader(Reader* reader, int buffer_size) - : reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) { + ShuffleReader(ReaderBase* reader, int buffer_size) + : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { buffer_.reserve(buffer_size); } std::vector ReadNext() override; - bool HasNext() const override { return reader_->HasNext(); } - - DDim shape(size_t idx) const override { return reader_->shape(idx); } - std::vector shapes() const override { return reader_->shapes(); } private: - Reader* reader_; int buffer_size_; std::vector> buffer_; size_t iteration_pos_; }; -class BatchReader : public Reader { +class BatchReader : public ReaderDecorator { public: - BatchReader(Reader* reader, int batch_size) - : reader_(reader), batch_size_(batch_size) {} + BatchReader(ReaderBase* reader, int batch_size) + : ReaderDecorator(reader), batch_size_(batch_size) {} std::vector ReadNext() override; - bool HasNext() const override { return reader_->HasNext(); }; - - DDim shape(size_t idx) const override { return reader_->shape(idx); } - std::vector shapes() const override { return reader_->shapes(); } private: - Reader* reader_; int batch_size_; std::vector> buffer_; }; -- GitLab