提交 3dfd1da1 编写于 作者: F fengjiayi

Complete CreateBatchReaderOp

上级 1696cb0e
......@@ -44,9 +44,9 @@ class FileReader : public ReaderBase {
std::vector<DDim> shapes_;
};
class ReaderDecorator : public ReaderBase {
class DecoratedReader : public ReaderBase {
public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {
explicit DecoratedReader(ReaderBase* reader) : reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
......@@ -105,10 +105,10 @@ class RandomReader : public FileReader {
// decorators
class ShuffleReader : public ReaderDecorator {
class ShuffleReader : public DecoratedReader {
public:
ShuffleReader(ReaderBase* reader, int buffer_size)
: ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) {
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
}
......@@ -120,10 +120,10 @@ class ShuffleReader : public ReaderDecorator {
size_t iteration_pos_;
};
class BatchReader : public ReaderDecorator {
class BatchReader : public DecoratedReader {
public:
BatchReader(ReaderBase* reader, int batch_size)
: ReaderDecorator(reader), batch_size_(batch_size) {
: DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
}
......
......@@ -19,11 +19,22 @@ namespace paddle {
namespace operators {
// general infershape for file readers
class CreateReaderInferShape : public framework::InferShapeBase {
class CreateFileReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CreateReaderOp should not be null.");
"The output file reader should not be null.");
}
};
// general infershape for decorated readers
class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
"Input(Underlying_reader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
}
};
......@@ -83,17 +94,6 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class CreateShuffleReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
"Input(Underlying_reader) of CreateShuffleReaderOp should "
"not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CreateShuffleReaderOp should not be null.");
}
};
class CreateShuffleReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
......@@ -121,7 +121,41 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and output the underlying reader's outputs in a shuffled order.
and yields the underlying reader's outputs in a shuffled order.
)DOC");
}
};
class CreateBatchReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::BatchReader(underlying_reader.Get(),
Attr<int>("batch_size")));
}
};
class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"Underlying_reader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
AddComment(R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
}
};
......@@ -131,9 +165,14 @@ class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker,
ops::CreateFileReaderInferShape,
ops::CreateRandomReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateShuffleReaderInferShape,
ops::CreateDecoratedReaderInferShape,
ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
ops::CreateDecoratedReaderInferShape,
ops::CreateBatchReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册