From 93cab64185edf722dc493d1a00db5032014d836e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 17:38:57 +0800 Subject: [PATCH] Complete CreateRandomReaderOp --- paddle/framework/reader.h | 37 +++++++----- paddle/operators/create_reader_op.cc | 90 ++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 16 deletions(-) create mode 100644 paddle/operators/create_reader_op.cc diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 3954a1bea8a..0669a7c7c75 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -33,8 +33,6 @@ class ReaderBase { class FileReader : public ReaderBase { public: - explicit FileReader(const std::vector& shapes) : shapes_(shapes) {} - DDim shape(size_t idx) const override; std::vector shapes() const override { return shapes_; } @@ -44,8 +42,6 @@ class FileReader : public ReaderBase { class ReaderDecorator : public ReaderBase { public: - explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {} - bool HasNext() const override { return reader_->HasNext(); } DDim shape(size_t idx) const override { return reader_->shape(idx); } @@ -60,19 +56,19 @@ class ReaderDecorator : public ReaderBase { template class RandomReader : public FileReader { public: - RandomReader(const std::vector& shapes, float min, float max) - : FileReader(shapes), min_(min), max_(max) { + void Initialize(const std::vector& shapes, float min, float max) { PADDLE_ENFORCE_LE(min, max, "'min' should be less than or equal to 'max'.(%f vs %f)", min, max); + shapes_ = shapes; + min_ = min; + max_ = max; + unsigned int seed = std::random_device()(); + engine_.seed(seed); + dist_ = std::uniform_real_distribution(min_, max_); } std::vector ReadNext() override { - std::minstd_rand engine; - unsigned int seed = std::random_device()(); - engine.seed(seed); - std::uniform_real_distribution dist(min_, max_); - std::vector res; res.reserve(shapes_.size()); for (const DDim& shape : shapes_) { @@ -85,7 +81,7 @@ class RandomReader : public FileReader { T* data = out.mutable_data(platform::CPUPlace()); int64_t numel = product(shape); for (int64_t i = 0; i < numel; ++i) { - data[i] = dist(engine); + data[i] = dist_(engine_); } res.push_back(out); } @@ -97,16 +93,21 @@ class RandomReader : public FileReader { private: float min_; float max_; + std::minstd_rand engine_; + std::uniform_real_distribution dist_; }; // decorators class ShuffleReader : public ReaderDecorator { public: - ShuffleReader(ReaderBase* reader, int buffer_size) - : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { + void Initialize(ReaderBase* reader, int buffer_size) { + reader_ = reader; + buffer_size_ = buffer_size; + iteration_pos_ = 0; buffer_.reserve(buffer_size); } + std::vector ReadNext() override; private: @@ -117,8 +118,12 @@ class ShuffleReader : public ReaderDecorator { class BatchReader : public ReaderDecorator { public: - BatchReader(ReaderBase* reader, int batch_size) - : ReaderDecorator(reader), batch_size_(batch_size) {} + void Initialize(ReaderBase* reader, int batch_size) { + reader_ = reader; + batch_size_ = batch_size; + buffer_.reserve(batch_size_); + } + std::vector ReadNext() override; private: diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc new file mode 100644 index 00000000000..abdc12087e0 --- /dev/null +++ b/paddle/operators/create_reader_op.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/reader.h" + +namespace paddle { +namespace operators { + +// general infershape +class CreateReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of CreateReaderOp should not be null."); + } +}; + +template +class CreateRandomReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& shape_concat = Attr>("shape_concat"); + const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), + int(shape_concat.size()), + "The accumulate of all ranks should be equal to the " + "shape concat's length."); + std::vector shapes; + int offset = 0; + for (int len : ranks) { + auto start_it = shape_concat.begin() + offset; + auto end_it = start_it + len; + shapes.push_back( + framework::make_ddim(std::vector(start_it, end_it))); + offset += len; + } + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable>(); + out->Initialize(shapes, Attr("min"), Attr("max")); + } +}; + +class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddOutput("Out", "(RandomReader) The created random reader."); + AddAttr>("shape_concat", + "The concat of all data's shapes."); + AddAttr>( + "ranks", + "The ranks of each data." + "e.g." + "shape_concat = [2,3,4,5,6]" + "ranks = [3,2]" + "It means the reader will generate two data each time," + "whose shapes are [2,3,4] and [5,6] respectively."); + AddAttr("min", "The lower bound of reader's uniform distribution."); + AddAttr("max", "The upper bound of reader's uniform distribution."); + AddComment(R"DOC( + CreateRandomReader Operator + + This Op creates a random reader. + The reader generates random data instead of really reading from files. + Generated data follow an uniform distribution between 'min' and 'max'. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, + ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, + paddle::framework::EmptyGradOpMaker); \ No newline at end of file -- GitLab