提交 d8cc21da 编写于 作者: F fengjiayi

refine inheritance relationship

上级 f32ca636
......@@ -17,7 +17,7 @@
namespace paddle {
namespace framework {
DDim Reader::shape(size_t idx) const {
DDim FileReader::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
......
......@@ -20,32 +20,48 @@
namespace paddle {
namespace framework {
class Reader {
class ReaderBase {
public:
Reader() {}
explicit Reader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
virtual std::vector<LoDTensor> ReadNext() = 0;
virtual bool HasNext() const = 0;
virtual DDim shape(size_t idx) const;
virtual std::vector<DDim> shapes() const { return shapes_; }
virtual DDim shape(size_t idx) const = 0;
virtual std::vector<DDim> shapes() const = 0;
virtual ~Reader() {}
virtual ~ReaderBase() {}
};
private:
// set private to prevent directly access in decorators
// a decorator should access its underlying reader_'s shape, not its own.
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
DDim shape(size_t idx) const override;
std::vector<DDim> shapes() const override { return shapes_; }
protected:
std::vector<DDim> shapes_;
};
class ReaderDecorator : public ReaderBase {
public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {}
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
protected:
ReaderBase* reader_;
};
// file readers
template <typename T>
class RandomReader : public Reader {
class RandomReader : public FileReader {
public:
RandomReader(const std::vector<DDim>& shapes, float min, float max)
: Reader(shapes), min_(min), max_(max) {
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(min, max,
"'min' should be less than or equal to 'max'.(%f vs %f)",
min, max);
......@@ -58,8 +74,8 @@ class RandomReader : public Reader {
std::uniform_real_distribution<float> dist(min_, max_);
std::vector<LoDTensor> res;
res.reserve(shapes().size());
for (const DDim& shape : shapes()) {
res.reserve(shapes_.size());
for (const DDim& shape : shapes_) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
"The rank of input data should be 2 at least.(Now it's %d)",
......@@ -85,37 +101,27 @@ class RandomReader : public Reader {
// decorators
class ShuffleReader : public Reader {
class ShuffleReader : public ReaderDecorator {
public:
ShuffleReader(Reader* reader, int buffer_size)
: reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) {
ShuffleReader(ReaderBase* reader, int buffer_size)
: ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
}
std::vector<LoDTensor> ReadNext() override;
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
private:
Reader* reader_;
int buffer_size_;
std::vector<std::vector<LoDTensor>> buffer_;
size_t iteration_pos_;
};
class BatchReader : public Reader {
class BatchReader : public ReaderDecorator {
public:
BatchReader(Reader* reader, int batch_size)
: reader_(reader), batch_size_(batch_size) {}
BatchReader(ReaderBase* reader, int batch_size)
: ReaderDecorator(reader), batch_size_(batch_size) {}
std::vector<LoDTensor> ReadNext() override;
bool HasNext() const override { return reader_->HasNext(); };
DDim shape(size_t idx) const override { return reader_->shape(idx); }
std::vector<DDim> shapes() const override { return reader_->shapes(); }
private:
Reader* reader_;
int batch_size_;
std::vector<std::vector<LoDTensor>> buffer_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册