提交 1696cb0e 编写于 作者: F fengjiayi

Complete CreateShuffleReaderOp

上级 93cab641
...@@ -33,6 +33,10 @@ class ReaderBase { ...@@ -33,6 +33,10 @@ class ReaderBase {
class FileReader : public ReaderBase { class FileReader : public ReaderBase {
public: public:
explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
DDim shape(size_t idx) const override; DDim shape(size_t idx) const override;
std::vector<DDim> shapes() const override { return shapes_; } std::vector<DDim> shapes() const override { return shapes_; }
...@@ -42,6 +46,10 @@ class FileReader : public ReaderBase { ...@@ -42,6 +46,10 @@ class FileReader : public ReaderBase {
class ReaderDecorator : public ReaderBase { class ReaderDecorator : public ReaderBase {
public: public:
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
bool HasNext() const override { return reader_->HasNext(); } bool HasNext() const override { return reader_->HasNext(); }
DDim shape(size_t idx) const override { return reader_->shape(idx); } DDim shape(size_t idx) const override { return reader_->shape(idx); }
...@@ -56,13 +64,11 @@ class ReaderDecorator : public ReaderBase { ...@@ -56,13 +64,11 @@ class ReaderDecorator : public ReaderBase {
template <typename T> template <typename T>
class RandomReader : public FileReader { class RandomReader : public FileReader {
public: public:
void Initialize(const std::vector<DDim>& shapes, float min, float max) { RandomReader(const std::vector<DDim>& shapes, float min, float max)
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(min, max, PADDLE_ENFORCE_LE(min, max,
"'min' should be less than or equal to 'max'.(%f vs %f)", "'min' should be less than or equal to 'max'.(%f vs %f)",
min, max); min, max);
shapes_ = shapes;
min_ = min;
max_ = max;
unsigned int seed = std::random_device()(); unsigned int seed = std::random_device()();
engine_.seed(seed); engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(min_, max_); dist_ = std::uniform_real_distribution<float>(min_, max_);
...@@ -101,10 +107,8 @@ class RandomReader : public FileReader { ...@@ -101,10 +107,8 @@ class RandomReader : public FileReader {
class ShuffleReader : public ReaderDecorator { class ShuffleReader : public ReaderDecorator {
public: public:
void Initialize(ReaderBase* reader, int buffer_size) { ShuffleReader(ReaderBase* reader, int buffer_size)
reader_ = reader; : ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_size_ = buffer_size;
iteration_pos_ = 0;
buffer_.reserve(buffer_size); buffer_.reserve(buffer_size);
} }
...@@ -118,9 +122,8 @@ class ShuffleReader : public ReaderDecorator { ...@@ -118,9 +122,8 @@ class ShuffleReader : public ReaderDecorator {
class BatchReader : public ReaderDecorator { class BatchReader : public ReaderDecorator {
public: public:
void Initialize(ReaderBase* reader, int batch_size) { BatchReader(ReaderBase* reader, int batch_size)
reader_ = reader; : ReaderDecorator(reader), batch_size_(batch_size) {
batch_size_ = batch_size;
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
} }
...@@ -131,5 +134,21 @@ class BatchReader : public ReaderDecorator { ...@@ -131,5 +134,21 @@ class BatchReader : public ReaderDecorator {
std::vector<std::vector<LoDTensor>> buffer_; std::vector<std::vector<LoDTensor>> buffer_;
}; };
class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
ReaderBase* Get() const { return reader_.get(); }
std::vector<LoDTensor> ReadNext() { return reader_->ReadNext(); }
bool HasNext() const { return reader_->HasNext(); }
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
private:
std::unique_ptr<ReaderBase> reader_;
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -62,7 +62,7 @@ function(op_library TARGET) ...@@ -62,7 +62,7 @@ function(op_library TARGET)
endif() endif()
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
...@@ -153,6 +153,7 @@ op_library(recurrent_op DEPS executor) ...@@ -153,6 +153,7 @@ op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(create_reader_op DEPS reader)
# Regist multiple Kernel to pybind # Regist multiple Kernel to pybind
if (WITH_GPU) if (WITH_GPU)
...@@ -178,7 +179,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) ...@@ -178,7 +179,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
op_library(${src}) op_library(${src})
endforeach() endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n")
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// general infershape // general infershape for file readers
class CreateReaderInferShape : public framework::InferShapeBase { class CreateReaderInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override { void operator()(framework::InferShapeContext* ctx) const override {
...@@ -35,6 +35,7 @@ class CreateRandomReaderOp : public framework::OperatorBase { ...@@ -35,6 +35,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat"); const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks"); const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()), int(shape_concat.size()),
"The accumulate of all ranks should be equal to the " "The accumulate of all ranks should be equal to the "
...@@ -49,8 +50,9 @@ class CreateRandomReaderOp : public framework::OperatorBase { ...@@ -49,8 +50,9 @@ class CreateRandomReaderOp : public framework::OperatorBase {
offset += len; offset += len;
} }
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::RandomReader<T>>(); ->template GetMutable<framework::ReaderHolder>();
out->Initialize(shapes, Attr<float>("min"), Attr<float>("max")); out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
} }
}; };
...@@ -58,7 +60,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -58,7 +60,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) { : OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(RandomReader) The created random reader."); AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat", AddAttr<std::vector<int>>("shape_concat",
"The concat of all data's shapes."); "The concat of all data's shapes.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
...@@ -81,10 +83,57 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -81,10 +83,57 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class CreateShuffleReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
"Input(Underlying_reader) of CreateShuffleReaderOp should "
"not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of CreateShuffleReaderOp should not be null.");
}
};
class CreateShuffleReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
Attr<int>("buffer_size")));
}
};
class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"Underlying_reader",
"(ReaderHolder) The underlying reader for creating a shuffle reader.");
AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
AddComment(R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and output the underlying reader's outputs in a shuffled order.
)DOC");
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>, REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker, ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
\ No newline at end of file REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateShuffleReaderInferShape,
ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册