提交 d8cc21da 编写于 作者: F fengjiayi

refine inheritance relationship

上级 f32ca636
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DDim Reader::shape(size_t idx) const { DDim FileReader::shape(size_t idx) const {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
idx, shapes_.size(), idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
......
...@@ -20,32 +20,48 @@ ...@@ -20,32 +20,48 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Reader { class ReaderBase {
public: public:
Reader() {}
explicit Reader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
virtual std::vector<LoDTensor> ReadNext() = 0; virtual std::vector<LoDTensor> ReadNext() = 0;
virtual bool HasNext() const = 0; virtual bool HasNext() const = 0;
virtual DDim shape(size_t idx) const; virtual DDim shape(size_t idx) const = 0;
virtual std::vector<DDim> shapes() const { return shapes_; } virtual std::vector<DDim> shapes() const = 0;
virtual ~Reader() {} virtual ~ReaderBase() {}
};
private: class FileReader : public ReaderBase {
// set private to prevent directly access in decorators public:
// a decorator should access its underlying reader_'s shape, not its own. explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
DDim shape(size_t idx) const override;
std::vector<DDim> shapes() const override { return shapes_; }
protected:
std::vector<DDim> shapes_; std::vector<DDim> shapes_;
}; };
class ReaderDecorator : public ReaderBase {
public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {}
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
protected:
ReaderBase* reader_;
};
// file readers // file readers
template <typename T> template <typename T>
class RandomReader : public Reader { class RandomReader : public FileReader {
public: public:
RandomReader(const std::vector<DDim>& shapes, float min, float max) RandomReader(const std::vector<DDim>& shapes, float min, float max)
: Reader(shapes), min_(min), max_(max) { : FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(min, max, PADDLE_ENFORCE_LE(min, max,
"'min' should be less than or equal to 'max'.(%f vs %f)", "'min' should be less than or equal to 'max'.(%f vs %f)",
min, max); min, max);
...@@ -58,8 +74,8 @@ class RandomReader : public Reader { ...@@ -58,8 +74,8 @@ class RandomReader : public Reader {
std::uniform_real_distribution<float> dist(min_, max_); std::uniform_real_distribution<float> dist(min_, max_);
std::vector<LoDTensor> res; std::vector<LoDTensor> res;
res.reserve(shapes().size()); res.reserve(shapes_.size());
for (const DDim& shape : shapes()) { for (const DDim& shape : shapes_) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
shape.size(), 2, shape.size(), 2,
"The rank of input data should be 2 at least.(Now it's %d)", "The rank of input data should be 2 at least.(Now it's %d)",
...@@ -85,37 +101,27 @@ class RandomReader : public Reader { ...@@ -85,37 +101,27 @@ class RandomReader : public Reader {
// decorators // decorators
class ShuffleReader : public Reader { class ShuffleReader : public ReaderDecorator {
public: public:
ShuffleReader(Reader* reader, int buffer_size) ShuffleReader(ReaderBase* reader, int buffer_size)
: reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) { : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size); buffer_.reserve(buffer_size);
} }
std::vector<LoDTensor> ReadNext() override; std::vector<LoDTensor> ReadNext() override;
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
private: private:
Reader* reader_;
int buffer_size_; int buffer_size_;
std::vector<std::vector<LoDTensor>> buffer_; std::vector<std::vector<LoDTensor>> buffer_;
size_t iteration_pos_; size_t iteration_pos_;
}; };
class BatchReader : public Reader { class BatchReader : public ReaderDecorator {
public: public:
BatchReader(Reader* reader, int batch_size) BatchReader(ReaderBase* reader, int batch_size)
: reader_(reader), batch_size_(batch_size) {} : ReaderDecorator(reader), batch_size_(batch_size) {}
std::vector<LoDTensor> ReadNext() override; std::vector<LoDTensor> ReadNext() override;
bool HasNext() const override { return reader_->HasNext(); };
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
private: private:
Reader* reader_;
int batch_size_; int batch_size_;
std::vector<std::vector<LoDTensor>> buffer_; std::vector<std::vector<LoDTensor>> buffer_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册