From f32ca6369099f5d3776ae87d431b9b39ea8eba3e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 31 Jan 2018 18:46:45 +0800 Subject: [PATCH] draft of Reader classes --- paddle/framework/CMakeLists.txt | 2 + paddle/framework/reader.cc | 107 +++++++++++++++++++++++++------- paddle/framework/reader.h | 83 +++++++++++++++++++++---- 3 files changed, 159 insertions(+), 33 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8c28709a68b..7eec91f9070 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -24,6 +24,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) +cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) + cc_test(variable_test SRCS variable_test.cc) cc_library(threadpool SRCS threadpool.cc DEPS enforce) diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index 7f80dd7fc10..e11662166c6 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -17,35 +17,100 @@ namespace paddle { namespace framework { -DDim Reader::shape(int idx) const { +DDim Reader::shape(size_t idx) const { PADDLE_ENFORCE_LT( idx, shapes_.size(), "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, shapes_.size()); + return shapes_[idx]; } -int RandomReader::ReadNext(std::vector* outs) { - PADDLE_ENFORCE_EQ( - shapes_.size(), outs.size(), - "shapes_.size() is %d, while outs.size() is %d. They are not equal.", - shapes_.size(), outs.size()); - std::minstd_rand engine; - unsigned int seed = std::random_device()(); - engine.seed(seed); - std::uniform_real_distribution dist(min_, max_); - for (int idx = 0; idx < shapes_.size(); ++idx) { - DDim shape = shapes_[idx]; - LoDTensor* out = outs[idx]; - int64_t numel = out->numel(); - PADDLE_ENFORCE_EQ(product(shape), numel, - "The product of %d'th shape is %lld, while the " - "corresponding out's numel is %lld. They are not equal.", - idx, product(shape), numel); - for (int64_t i = 0; i < numel, ++i) { - out[i] = dist(engine); +std::vector ShuffleReader::ReadNext() { + if (iteration_pos_ >= buffer_.size()) { + // Reload buffer with new data + buffer_.clear(); + for (int i = 0; i < buffer_size_; ++i) { + if (reader_->HasNext()) { + buffer_.push_back(reader_->ReadNext()); + } else { + break; + } } + std::random_shuffle(buffer_.begin(), buffer_.end()); + iteration_pos_ = 0; } - return 0; + if (buffer_.empty()) { + std::vector empty_res; + return empty_res; + } + return buffer_[iteration_pos_++]; +} + +std::vector BatchReader::ReadNext() { + buffer_.clear(); + for (int i = 0; i < batch_size_; ++i) { + if (reader_->HasNext()) { + buffer_.push_back(reader_->ReadNext()); + } else { + break; + } + } + // Concat instances + std::vector res; + if (buffer_.empty()) { + return res; + } + int out_num = buffer_[0].size(); + res.reserve(out_num); + for (int j = 0; j < out_num; ++j) { + // Merge shape and check date type + std::type_index batch_type = buffer_[0][j].type(); + DDim batch_shape = buffer_[0][j].dims(); + for (size_t i = 1; i < buffer_.size(); ++i) { + std::type_index ins_type = buffer_[i][j].type(); + DDim ins_shape = buffer_[i][j].dims(); + PADDLE_ENFORCE_EQ(batch_type, ins_type); + PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()), + slice_ddim(ins_shape, 1, ins_shape.size())); + PADDLE_ENFORCE_GT(ins_shape[0], 0); + batch_shape[0] += ins_shape[0]; + } + + LoDTensor out; + out.Resize(batch_shape); + out.mutable_data(platform::CPUPlace(), batch_type); + int64_t dst_offset = 0; + + // Merge lod and data + LoD batch_lod; + std::vector top_level_lod({0}); + for (size_t i = 0; i < buffer_.size(); ++i) { + DDim ins_shape = buffer_[i][j].dims(); + LoD ins_lod = buffer_[i][j].lod(); + if (i == 0) { + batch_lod = ins_lod; + } else { + PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size()); + for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) { + auto& lod_level = batch_lod[level_idx]; + for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) { + lod_level.push_back(ins_lod[level_idx][k] + lod_level.back()); + } + } + } + top_level_lod.push_back( + top_level_lod.back() + + (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1))); + + Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]); + Copy(buffer_[i][j], platform::CPUPlace(), &dst); + dst_offset += ins_shape[0]; + } + batch_lod.insert(batch_lod.begin(), top_level_lod); + out.set_lod(batch_lod); + res.push_back(out); + } + return res; } } // namespace framework } // namespace paddle diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index eed9c18d087..58675863e56 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -22,20 +22,61 @@ namespace framework { class Reader { public: - virtual int ReadNext(std::vector* outs) = 0; - DDim shape(int idx) const; + Reader() {} + explicit Reader(const std::vector& shapes) : shapes_(shapes) {} + + virtual std::vector ReadNext() = 0; + virtual bool HasNext() const = 0; + + virtual DDim shape(size_t idx) const; + virtual std::vector shapes() const { return shapes_; } + + virtual ~Reader() {} private: + // set private to prevent directly access in decorators + // a decorator should access its underlying reader_'s shape, not its own. std::vector shapes_; }; // file readers +template class RandomReader : public Reader { public: RandomReader(const std::vector& shapes, float min, float max) - : shapes_(shapes), min_(min), max_(max) {} - int ReadNext(std::vector* outs) override; + : Reader(shapes), min_(min), max_(max) { + PADDLE_ENFORCE_LE(min, max, + "'min' should be less than or equal to 'max'.(%f vs %f)", + min, max); + } + + std::vector ReadNext() override { + std::minstd_rand engine; + unsigned int seed = std::random_device()(); + engine.seed(seed); + std::uniform_real_distribution dist(min_, max_); + + std::vector res; + res.reserve(shapes().size()); + for (const DDim& shape : shapes()) { + PADDLE_ENFORCE_GE( + shape.size(), 2, + "The rank of input data should be 2 at least.(Now it's %d)", + shape.size()); + LoDTensor out; + out.Resize(shape); + T* data = out.mutable_data(platform::CPUPlace()); + int64_t numel = product(shape); + for (int64_t i = 0; i < numel; ++i) { + data[i] = dist(engine); + } + res.push_back(out); + } + return res; + } + + bool HasNext() const override { return true; } private: float min_; @@ -44,22 +85,40 @@ class RandomReader : public Reader { // decorators -class BatchReader : public Reader { +class ShuffleReader : public Reader { public: - BatchReader(const Reader* reader) : reader_(reader) {} - int ReadNext(std::vector* outs) override; + ShuffleReader(Reader* reader, int buffer_size) + : reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) { + buffer_.reserve(buffer_size); + } + std::vector ReadNext() override; + bool HasNext() const override { return reader_->HasNext(); } + + DDim shape(size_t idx) const override { return reader_->shape(idx); } + std::vector shapes() const override { return reader_->shapes(); } private: - const Reader* reader_; + Reader* reader_; + int buffer_size_; + std::vector> buffer_; + size_t iteration_pos_; }; -class ShuffleReader : public Reader { +class BatchReader : public Reader { public: - ShuffleReader(const Reader* reader) : reader_(reader) {} - int ReadNext(std::vector* outs) override; + BatchReader(Reader* reader, int batch_size) + : reader_(reader), batch_size_(batch_size) {} + std::vector ReadNext() override; + bool HasNext() const override { return reader_->HasNext(); }; + + DDim shape(size_t idx) const override { return reader_->shape(idx); } + std::vector shapes() const override { return reader_->shapes(); } private: - const Reader* reader_; + Reader* reader_; + int batch_size_; + std::vector> buffer_; }; + } // namespace framework } // namespace paddle -- GitLab