提交 3dfd1da1 编写于 作者: F fengjiayi

Complete CreateBatchReaderOp

上级 1696cb0e
...@@ -44,9 +44,9 @@ class FileReader : public ReaderBase { ...@@ -44,9 +44,9 @@ class FileReader : public ReaderBase {
std::vector<DDim> shapes_; std::vector<DDim> shapes_;
}; };
class ReaderDecorator : public ReaderBase { class DecoratedReader : public ReaderBase {
public: public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) { explicit DecoratedReader(ReaderBase* reader) : reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
...@@ -105,10 +105,10 @@ class RandomReader : public FileReader { ...@@ -105,10 +105,10 @@ class RandomReader : public FileReader {
// decorators // decorators
class ShuffleReader : public ReaderDecorator { class ShuffleReader : public DecoratedReader {
public: public:
ShuffleReader(ReaderBase* reader, int buffer_size) ShuffleReader(ReaderBase* reader, int buffer_size)
: ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) { : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size); buffer_.reserve(buffer_size);
} }
...@@ -120,10 +120,10 @@ class ShuffleReader : public ReaderDecorator { ...@@ -120,10 +120,10 @@ class ShuffleReader : public ReaderDecorator {
size_t iteration_pos_; size_t iteration_pos_;
}; };
class BatchReader : public ReaderDecorator { class BatchReader : public DecoratedReader {
public: public:
BatchReader(ReaderBase* reader, int batch_size) BatchReader(ReaderBase* reader, int batch_size)
: ReaderDecorator(reader), batch_size_(batch_size) { : DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
} }
......
...@@ -19,11 +19,22 @@ namespace paddle { ...@@ -19,11 +19,22 @@ namespace paddle {
namespace operators { namespace operators {
// general infershape for file readers // general infershape for file readers
class CreateReaderInferShape : public framework::InferShapeBase { class CreateFileReaderInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override { void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CreateReaderOp should not be null."); "The output file reader should not be null.");
}
};
// general infershape for decorated readers
class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
"Input(Underlying_reader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
} }
}; };
...@@ -83,17 +94,6 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,17 +94,6 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class CreateShuffleReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
"Input(Underlying_reader) of CreateShuffleReaderOp should "
"not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CreateShuffleReaderOp should not be null.");
}
};
class CreateShuffleReaderOp : public framework::OperatorBase { class CreateShuffleReaderOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
...@@ -121,7 +121,41 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -121,7 +121,41 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateShuffleReader Operator CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader' A shuffle reader takes another reader as its 'underlying reader'
and output the underlying reader's outputs in a shuffled order. and yields the underlying reader's outputs in a shuffled order.
)DOC");
}
};
class CreateBatchReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::BatchReader(underlying_reader.Get(),
Attr<int>("batch_size")));
}
};
class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"Underlying_reader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
AddComment(R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC"); )DOC");
} }
}; };
...@@ -131,9 +165,14 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -131,9 +165,14 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>, REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, ops::CreateFileReaderInferShape,
ops::CreateRandomReaderOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateShuffleReaderInferShape, ops::CreateDecoratedReaderInferShape,
ops::CreateShuffleReaderOpMaker, ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
ops::CreateDecoratedReaderInferShape,
ops::CreateBatchReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册