From 1696cb0e510a8d52427b6ca96900bab4e03b5af1 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 1 Feb 2018 21:10:16 +0800 Subject: [PATCH] Complete CreateShuffleReaderOp --- paddle/framework/reader.h | 41 +++++++++++++------ paddle/operators/CMakeLists.txt | 5 ++- paddle/operators/create_reader_op.cc | 59 +++++++++++++++++++++++++--- 3 files changed, 87 insertions(+), 18 deletions(-) diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 0669a7c7c7..18a34bfd17 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -33,6 +33,10 @@ class ReaderBase { class FileReader : public ReaderBase { public: + explicit FileReader(const std::vector& shapes) : shapes_(shapes) { + PADDLE_ENFORCE(!shapes_.empty()); + } + DDim shape(size_t idx) const override; std::vector shapes() const override { return shapes_; } @@ -42,6 +46,10 @@ class FileReader : public ReaderBase { class ReaderDecorator : public ReaderBase { public: + explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) { + PADDLE_ENFORCE_NOT_NULL(reader_); + } + bool HasNext() const override { return reader_->HasNext(); } DDim shape(size_t idx) const override { return reader_->shape(idx); } @@ -56,13 +64,11 @@ class ReaderDecorator : public ReaderBase { template class RandomReader : public FileReader { public: - void Initialize(const std::vector& shapes, float min, float max) { + RandomReader(const std::vector& shapes, float min, float max) + : FileReader(shapes), min_(min), max_(max) { PADDLE_ENFORCE_LE(min, max, "'min' should be less than or equal to 'max'.(%f vs %f)", min, max); - shapes_ = shapes; - min_ = min; - max_ = max; unsigned int seed = std::random_device()(); engine_.seed(seed); dist_ = std::uniform_real_distribution(min_, max_); @@ -101,10 +107,8 @@ class RandomReader : public FileReader { class ShuffleReader : public ReaderDecorator { public: - void Initialize(ReaderBase* reader, int buffer_size) { - reader_ = reader; - buffer_size_ = buffer_size; - iteration_pos_ = 0; + ShuffleReader(ReaderBase* reader, int buffer_size) + : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { buffer_.reserve(buffer_size); } @@ -118,9 +122,8 @@ class ShuffleReader : public ReaderDecorator { class BatchReader : public ReaderDecorator { public: - void Initialize(ReaderBase* reader, int batch_size) { - reader_ = reader; - batch_size_ = batch_size; + BatchReader(ReaderBase* reader, int batch_size) + : ReaderDecorator(reader), batch_size_(batch_size) { buffer_.reserve(batch_size_); } @@ -131,5 +134,21 @@ class BatchReader : public ReaderDecorator { std::vector> buffer_; }; +class ReaderHolder { + public: + void Reset(ReaderBase* reader) { reader_.reset(reader); } + + ReaderBase* Get() const { return reader_.get(); } + + std::vector ReadNext() { return reader_->ReadNext(); } + bool HasNext() const { return reader_->HasNext(); } + + DDim shape(size_t idx) const { return reader_->shape(idx); } + std::vector shapes() const { return reader_->shapes(); } + + private: + std::unique_ptr reader_; +}; + } // namespace framework } // namespace paddle diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 48cf5816cc..3684eb0dcc 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -62,7 +62,7 @@ function(op_library TARGET) endif() # Define operators that don't need pybind here. - foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") + foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() @@ -153,6 +153,7 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(create_reader_op DEPS reader) # Regist multiple Kernel to pybind if (WITH_GPU) @@ -178,7 +179,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index abdc12087e..29b487e10b 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -18,7 +18,7 @@ namespace paddle { namespace operators { -// general infershape +// general infershape for file readers class CreateReaderInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { @@ -35,6 +35,7 @@ class CreateRandomReaderOp : public framework::OperatorBase { const platform::Place& dev_place) const override { const auto& shape_concat = Attr>("shape_concat"); const auto& ranks = Attr>("ranks"); + PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), int(shape_concat.size()), "The accumulate of all ranks should be equal to the " @@ -49,8 +50,9 @@ class CreateRandomReaderOp : public framework::OperatorBase { offset += len; } auto* out = scope.FindVar(Output("Out")) - ->template GetMutable>(); - out->Initialize(shapes, Attr("min"), Attr("max")); + ->template GetMutable(); + out->Reset(new framework::RandomReader(shapes, Attr("min"), + Attr("max"))); } }; @@ -58,7 +60,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { public: CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(op_proto, op_checker) { - AddOutput("Out", "(RandomReader) The created random reader."); + AddOutput("Out", "(ReaderHolder) The created random reader."); AddAttr>("shape_concat", "The concat of all data's shapes."); AddAttr>( @@ -81,10 +83,57 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class CreateShuffleReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), + "Input(Underlying_reader) of CreateShuffleReaderOp should " + "not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of CreateShuffleReaderOp should not be null."); + } +}; + +class CreateShuffleReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& underlying_reader = scope.FindVar(Input("Underlying_reader")) + ->Get(); + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new framework::ShuffleReader(underlying_reader.Get(), + Attr("buffer_size"))); + } +}; + +class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddInput( + "Underlying_reader", + "(ReaderHolder) The underlying reader for creating a shuffle reader."); + AddOutput("Out", "(ReaderHolder) The created shuffle reader."); + AddAttr("buffer_size", "The shuffle buffer size.").GreaterThan(0); + AddComment(R"DOC( + CreateShuffleReader Operator + + A shuffle reader takes another reader as its 'underlying reader' + and output the underlying reader's outputs in a shuffled order. + )DOC"); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, - paddle::framework::EmptyGradOpMaker); \ No newline at end of file + paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, + ops::CreateShuffleReaderInferShape, + ops::CreateShuffleReaderOpMaker, + paddle::framework::EmptyGradOpMaker); -- GitLab