提交 1696cb0e 编写于 作者: F fengjiayi

Complete CreateShuffleReaderOp

上级 93cab641
......@@ -33,6 +33,10 @@ class ReaderBase {
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
DDim shape(size_t idx) const override;
std::vector<DDim> shapes() const override { return shapes_; }
......@@ -42,6 +46,10 @@ class FileReader : public ReaderBase {
class ReaderDecorator : public ReaderBase {
public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); }
......@@ -56,13 +64,11 @@ class ReaderDecorator : public ReaderBase {
template <typename T>
class RandomReader : public FileReader {
public:
void Initialize(const std::vector<DDim>& shapes, float min, float max) {
RandomReader(const std::vector<DDim>& shapes, float min, float max)
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(min, max,
"'min' should be less than or equal to 'max'.(%f vs %f)",
min, max);
shapes_ = shapes;
min_ = min;
max_ = max;
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(min_, max_);
......@@ -101,10 +107,8 @@ class RandomReader : public FileReader {
class ShuffleReader : public ReaderDecorator {
public:
void Initialize(ReaderBase* reader, int buffer_size) {
reader_ = reader;
buffer_size_ = buffer_size;
iteration_pos_ = 0;
ShuffleReader(ReaderBase* reader, int buffer_size)
: ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
}
......@@ -118,9 +122,8 @@ class ShuffleReader : public ReaderDecorator {
class BatchReader : public ReaderDecorator {
public:
void Initialize(ReaderBase* reader, int batch_size) {
reader_ = reader;
batch_size_ = batch_size;
BatchReader(ReaderBase* reader, int batch_size)
: ReaderDecorator(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
}
......@@ -131,5 +134,21 @@ class BatchReader : public ReaderDecorator {
std::vector<std::vector<LoDTensor>> buffer_;
};
class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
ReaderBase* Get() const { return reader_.get(); }
std::vector<LoDTensor> ReadNext() { return reader_->ReadNext(); }
bool HasNext() const { return reader_->HasNext(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
} // namespace framework
} // namespace paddle
......@@ -62,7 +62,7 @@ function(op_library TARGET)
endif()
# Define operators that don't need pybind here.
foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op")
foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......@@ -153,6 +153,7 @@ op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
op_library(create_reader_op DEPS reader)
# Regist multiple Kernel to pybind
if (WITH_GPU)
......@@ -178,7 +179,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n")
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
......
......@@ -18,7 +18,7 @@
namespace paddle {
namespace operators {
// general infershape
// general infershape for file readers
class CreateReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
......@@ -35,6 +35,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
......@@ -49,8 +50,9 @@ class CreateRandomReaderOp : public framework::OperatorBase {
offset += len;
}
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::RandomReader<T>>();
out->Initialize(shapes, Attr<float>("min"), Attr<float>("max"));
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
}
};
......@@ -58,7 +60,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(RandomReader) The created random reader.");
AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat",
"The concat of all data's shapes.");
AddAttr<std::vector<int>>(
......@@ -81,6 +83,49 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
// Shape-inference functor for the create_shuffle_reader operator.
// It only validates that the operator is wired up: the "Underlying_reader"
// input and the "Out" output must both be present. No shapes are set here.
class CreateShuffleReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    // The input is checked first so its error takes precedence.
    const bool has_underlying_reader = ctx->HasInput("Underlying_reader");
    PADDLE_ENFORCE(has_underlying_reader,
                   "Input(Underlying_reader) of CreateShuffleReaderOp should "
                   "not be null.");

    const bool has_out = ctx->HasOutput("Out");
    PADDLE_ENFORCE(has_out,
                   "Output(Out) of CreateShuffleReaderOp should not be null.");
  }
};
// Operator that wraps an existing reader in a framework::ShuffleReader.
//
// Run() looks up the ReaderHolder stored in the "Underlying_reader" input
// variable, builds a ShuffleReader on top of the reader it holds (using the
// "buffer_size" attribute), and installs the new reader into the ReaderHolder
// of the "Out" output variable.
class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    // FindVar may return nullptr when the variable is absent from the scope;
    // fail with an enforce error rather than dereferencing a null pointer.
    auto* underlying_var = scope.FindVar(Input("Underlying_reader"));
    PADDLE_ENFORCE(underlying_var != nullptr,
                   "Cannot find the variable of Input(Underlying_reader) in "
                   "the scope.");
    auto* out_var = scope.FindVar(Output("Out"));
    PADDLE_ENFORCE(out_var != nullptr,
                   "Cannot find the variable of Output(Out) in the scope.");

    const auto& underlying_reader =
        underlying_var->Get<framework::ReaderHolder>();
    auto* out = out_var->GetMutable<framework::ReaderHolder>();

    // The ShuffleReader only receives the raw pointer from Get(); the input
    // holder keeps owning the underlying reader.
    // NOTE(review): the underlying reader presumably has to outlive the
    // created shuffle reader -- confirm the variable lifetimes at call sites.
    out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
                                            Attr<int>("buffer_size")));
  }
};
// Proto/attribute declaration for the create_shuffle_reader operator:
//   input  "Underlying_reader" -- the reader to be wrapped,
//   output "Out"               -- the created shuffle reader,
//   attr   "buffer_size"       -- int, constrained to be greater than 0.
class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    // Input and output are declared before the attribute, as in the original.
    AddInput("Underlying_reader",
             "(ReaderHolder) The underlying reader for creating a shuffle reader.");
    AddOutput("Out", "(ReaderHolder) The created shuffle reader.");

    // The attribute checker rejects non-positive buffer sizes.
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);

    // The raw string is kept flush-left: its whitespace is part of the text.
    AddComment(R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and output the underlying reader's outputs in a shuffled order.
)DOC");
  }
};
} // namespace operators
} // namespace paddle
......@@ -88,3 +133,7 @@ namespace ops = paddle::operators;
// Registers the "create_random_reader" operator, instantiated for float, with
// the shared file-reader shape inference and its proto maker.
// NOTE(review): EmptyGradOpMaker presumably means no gradient operator is
// generated for this op -- confirm against the framework definition.
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
// Registers the "create_shuffle_reader" operator with its own shape inference
// and proto maker; it uses the same EmptyGradOpMaker as above.
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateShuffleReaderInferShape,
ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册