diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index 18a34bfd170116ca8984a34c8a2661251bffd3ad..8275ea474b41d7b749ab4c6a55781a28f3b36095 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -44,9 +44,9 @@ class FileReader : public ReaderBase { std::vector shapes_; }; -class ReaderDecorator : public ReaderBase { +class DecoratedReader : public ReaderBase { public: - explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) { + explicit DecoratedReader(ReaderBase* reader) : reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } @@ -105,10 +105,10 @@ class RandomReader : public FileReader { // decorators -class ShuffleReader : public ReaderDecorator { +class ShuffleReader : public DecoratedReader { public: ShuffleReader(ReaderBase* reader, int buffer_size) - : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { + : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { buffer_.reserve(buffer_size); } @@ -120,10 +120,10 @@ class ShuffleReader : public ReaderDecorator { size_t iteration_pos_; }; -class BatchReader : public ReaderDecorator { +class BatchReader : public DecoratedReader { public: BatchReader(ReaderBase* reader, int batch_size) - : ReaderDecorator(reader), batch_size_(batch_size) { + : DecoratedReader(reader), batch_size_(batch_size) { buffer_.reserve(batch_size_); } diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc index 29b487e10b5c67f15cd9285482e43f34c4ff51a8..9cf27bbfc694b71ced66b56bb09aa14b2dcbe9d7 100644 --- a/paddle/operators/create_reader_op.cc +++ b/paddle/operators/create_reader_op.cc @@ -19,11 +19,22 @@ namespace paddle { namespace operators { // general infershape for file readers -class CreateReaderInferShape : public framework::InferShapeBase { +class CreateFileReaderInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of CreateReaderOp should not be null."); + "The output file reader should not be null."); + } +}; + +// general infershape for decorated readers +class CreateDecoratedReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), + "Input(Underlying_reader) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "The output decorated reader should not be null."); } }; @@ -83,17 +94,6 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class CreateShuffleReaderInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"), - "Input(Underlying_reader) of CreateShuffleReaderOp should " - "not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of CreateShuffleReaderOp should not be null."); - } -}; - class CreateShuffleReaderOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; @@ -121,7 +121,41 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { CreateShuffleReader Operator A shuffle reader takes another reader as its 'underlying reader' - and output the underlying reader's outputs in a shuffled order. + and yields the underlying reader's outputs in a shuffled order. + )DOC"); + } +}; + +class CreateBatchReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& underlying_reader = scope.FindVar(Input("Underlying_reader")) + ->Get(); + auto* out = scope.FindVar(Output("Out")) + ->template GetMutable(); + out->Reset(new framework::BatchReader(underlying_reader.Get(), + Attr("batch_size"))); + } +}; + +class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddInput( + "Underlying_reader", + "(ReaderHolder) The underlying reader for creating a batch reader."); + AddOutput("Out", "(ReaderHolder) The created batch reader."); + AddAttr("batch_size", + "How many instances the batch reader yields each time.") + .GreaterThan(0); + AddComment(R"DOC( + CreateBatchReader Operator + + A batch reader takes another reader as its 'underlying reader', + gathers the underlying reader's outputs and then yields them in batches. )DOC"); } }; @@ -131,9 +165,14 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp, - ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, + ops::CreateFileReaderInferShape, + ops::CreateRandomReaderOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, - ops::CreateShuffleReaderInferShape, + ops::CreateDecoratedReaderInferShape, ops::CreateShuffleReaderOpMaker, paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp, + ops::CreateDecoratedReaderInferShape, + ops::CreateBatchReaderOpMaker, + paddle::framework::EmptyGradOpMaker);