From 4d8345e3ac0ba79a17359a72e940ade284c0b1a9 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 6 Mar 2018 17:31:16 +0800 Subject: [PATCH] Extract create_reader_op to three files --- paddle/fluid/framework/reader.cc | 87 ------- paddle/fluid/framework/reader.h | 79 +----- paddle/fluid/operators/CMakeLists.txt | 8 +- paddle/fluid/operators/create_reader_op.cc | 246 ------------------ paddle/fluid/operators/detail/CMakeLists.txt | 4 +- paddle/fluid/operators/reader/CMakeLists.txt | 4 + .../reader/create_batch_reader_op.cc | 137 ++++++++++ .../reader/create_random_data_generator_op.cc | 110 ++++++++ .../reader/create_shuffle_reader_op.cc | 97 +++++++ .../operators/reader/reader_op_registry.cc | 116 +++++++++ .../operators/reader/reader_op_registry.h | 75 ++++++ 11 files changed, 548 insertions(+), 415 deletions(-) delete mode 100644 paddle/fluid/operators/create_reader_op.cc create mode 100644 paddle/fluid/operators/reader/CMakeLists.txt create mode 100644 paddle/fluid/operators/reader/create_batch_reader_op.cc create mode 100644 paddle/fluid/operators/reader/create_random_data_generator_op.cc create mode 100644 paddle/fluid/operators/reader/create_shuffle_reader_op.cc create mode 100644 paddle/fluid/operators/reader/reader_op_registry.cc create mode 100644 paddle/fluid/operators/reader/reader_op_registry.h diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index dc1caa72a4c..91879d6d458 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -25,92 +25,5 @@ DDim ReaderBase::shape(size_t idx) const { return shapes_[idx]; } -void ShuffleReader::ReadNext(std::vector* out) { - if (iteration_pos_ >= buffer_.size()) { - // Reload buffer with new data - buffer_.clear(); - buffer_.reserve(buffer_size_); - for (int i = 0; i < buffer_size_; ++i) { - if (reader_->HasNext()) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - } else { - break; - } - } - // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be - // optimize. - std::random_shuffle(buffer_.begin(), buffer_.end()); - iteration_pos_ = 0; - } - out->clear(); - if (!buffer_.empty()) { - std::swap(*out, buffer_[iteration_pos_++]); - } - // if buffer_ is empty, the 'out' will return as an empty vector. -} - -void BatchReader::ReadNext(std::vector* out) { - buffer_.clear(); - buffer_.reserve(batch_size_); - for (int i = 0; i < batch_size_; ++i) { - if (reader_->HasNext()) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - } else { - break; - } - } - // Concat instances - out->clear(); - if (buffer_.empty()) { - // if buffer_ is empty, the 'out' will return as an empty vector. 
- return; - } - int out_num = buffer_[0].size(); - out->reserve(out_num); - for (int j = 0; j < out_num; ++j) { - // Merge shape and check date type - std::type_index batch_type = buffer_[0][j].type(); - DDim batch_shape = buffer_[0][j].dims(); - for (size_t i = 1; i < buffer_.size(); ++i) { - std::type_index ins_type = buffer_[i][j].type(); - DDim ins_shape = buffer_[i][j].dims(); - PADDLE_ENFORCE_EQ(batch_type, ins_type); - PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()), - slice_ddim(ins_shape, 1, ins_shape.size())); - PADDLE_ENFORCE_GT(ins_shape[0], 0); - batch_shape[0] += ins_shape[0]; - } - - LoDTensor out_tensor; - out_tensor.Resize(batch_shape); - out_tensor.mutable_data(platform::CPUPlace(), batch_type); - int64_t dst_offset = 0; - - // Merge lod and data - LoD batch_lod; - for (size_t i = 0; i < buffer_.size(); ++i) { - DDim ins_shape = buffer_[i][j].dims(); - LoD ins_lod = buffer_[i][j].lod(); - if (i == 0) { - batch_lod = ins_lod; - } else { - PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size()); - for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) { - auto& lod_level = batch_lod[level_idx]; - for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) { - lod_level.push_back(ins_lod[level_idx][k] + lod_level.back()); - } - } - } - Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]); - TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst); - dst_offset += ins_shape[0]; - } - out_tensor.set_lod(batch_lod); - out->push_back(out_tensor); - } -} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 4a5eba5fb73..27ab6e750c2 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -60,83 +60,8 @@ class DecoratedReader : public ReaderBase { ReaderBase* reader_; }; -// file readers - -template -class RandomDataGenerator : public FileReader { - public: - RandomDataGenerator(const std::vector& shapes, float min, float max) - : FileReader(shapes), min_(min), max_(max) { - PADDLE_ENFORCE_LE( - min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); - unsigned int seed = std::random_device()(); - engine_.seed(seed); - dist_ = std::uniform_real_distribution(min_, max_); - } - - void ReadNext(std::vector* out) override { - out->clear(); - out->reserve(shapes_.size()); - for (const DDim& shape : shapes_) { - PADDLE_ENFORCE_GE( - shape.size(), 2, - "The rank of reader's output data should be 2 at least.(Now it's %d)", - shape.size()); - LoDTensor out_tensor; - out_tensor.Resize(shape); - T* data = out_tensor.mutable_data(platform::CPUPlace()); - int64_t numel = product(shape); - for (int64_t i = 0; i < numel; ++i) { - data[i] = dist_(engine_); - } - out->push_back(out_tensor); - } - } - - bool HasNext() const override { return true; } - - void ReInit() override { return; } - - private: - float min_; - float max_; - std::minstd_rand engine_; - std::uniform_real_distribution dist_; -}; - -// decorated readers - -class ShuffleReader : public DecoratedReader { - public: - ShuffleReader(ReaderBase* reader, int buffer_size) - : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { - buffer_.reserve(buffer_size); - } - - void ReadNext(std::vector* out) override; - - private: - int buffer_size_; - std::vector> buffer_; - size_t iteration_pos_; -}; - -class BatchReader : public DecoratedReader { - public: - BatchReader(ReaderBase* reader, int batch_size) - : DecoratedReader(reader), batch_size_(batch_size) { - 
buffer_.reserve(batch_size_); - } - - void ReadNext(std::vector* out) override; - - private: - int batch_size_; - std::vector> buffer_; -}; - -// The ReaderHolder is used as readers' unified wrapper, -// making it easier to access different type readers in Variables. +// The ReaderHolder is used as reader' unified wrapper, +// making it easier to access different type reader in Variables. class ReaderHolder { public: void Reset(ReaderBase* reader) { reader_.reset(reader); } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e1c02ec1613..5ad58908f5b 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -70,7 +70,7 @@ function(op_library TARGET) endif() # Define operators that don't need pybind here. - foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op") + foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() @@ -128,8 +128,8 @@ else() set(DEPS_OPS ${DEPS_OPS} nccl_op) endif() +add_subdirectory(detail) if(WITH_DISTRIBUTE) - add_subdirectory(detail) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(send_op DEPS ${DISTRIBUTE_DEPS}) @@ -170,7 +170,6 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) -op_library(create_reader_op DEPS reader) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv) @@ -189,7 +188,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) op_library(${src}) endforeach() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n") +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") @@ -204,3 +203,4 @@ if(WITH_GPU) endif() cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) +add_subdirectory(reader) diff --git a/paddle/fluid/operators/create_reader_op.cc b/paddle/fluid/operators/create_reader_op.cc deleted file mode 100644 index 17ed7e24eca..00000000000 --- a/paddle/fluid/operators/create_reader_op.cc +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/reader.h" - -namespace paddle { -namespace operators { - -static std::vector RestoreShapes( - const std::vector& shape_concat, const std::vector& ranks) { - std::vector res; - int offset = 0; - for (int len : ranks) { - auto start_it = shape_concat.begin() + offset; - auto end_it = start_it + len; - res.push_back(framework::make_ddim(std::vector(start_it, end_it))); - offset += len; - } - return res; -} - -// general infershape for file readers -class CreateFileReaderInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "The output file reader should not be null."); - const auto shape_concat = - ctx->Attrs().Get>("shape_concat"); - const auto ranks = ctx->Attrs().Get>("ranks"); - std::vector shapes = RestoreShapes(shape_concat, ranks); - ctx->SetReaderDims("Out", shapes); - - if (ctx->IsRuntime()) { - const auto lod_levels = ctx->Attrs().Get>("lod_levels"); - PADDLE_ENFORCE_EQ( - lod_levels.size(), shapes.size(), - "The number of 'lod_levels'(%d) doesn't match the number " - "of 'shapes'(%d).", - lod_levels.size(), shapes.size()); - framework::VarDesc* reader = - boost::get(ctx->GetOutputVarPtrs("Out")[0]); - reader->SetLoDLevels(lod_levels); - } - } -}; - -// general infershape for decorated readers -class CreateDecoratedReaderInferShape : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"), - "Input(UnderlyingReader) should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "The output decorated reader should not be null."); - ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader")); - - if (ctx->IsRuntime()) { - framework::VarDesc* in_reader = boost::get( - ctx->GetInputVarPtrs("UnderlyingReader")[0]); - framework::VarDesc* out_reader = - boost::get(ctx->GetOutputVarPtrs("Out")[0]); - out_reader->SetLoDLevels(in_reader->GetLoDLevels()); - } - } -}; - -// general var type inference for file readers -class CreateFileReaderInferVarType : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - std::string reader_name = op_desc.Output("Out")[0]; - framework::VarDesc* reader = block->FindVarRecursive(reader_name); - reader->SetType(framework::proto::VarType::READER); - } -}; - -// general var type inference for decorated readers -class CreateDecoratedReaderInferVarType : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - std::string in_reader_name = op_desc.Input("UnderlyingReader")[0]; - framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name); - std::string out_reader_name = op_desc.Output("Out")[0]; - framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name); - out_reader->SetType(framework::proto::VarType::READER); - out_reader->SetDataTypes(in_reader->GetDataTypes()); - } -}; - -template -class CreateRandomDataGeneratorOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - const auto& shape_concat = Attr>("shape_concat"); - const auto& ranks = Attr>("ranks"); - 
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); - PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), - int(shape_concat.size()), - "The accumulate of all ranks should be equal to the " - "shape concat's length."); - std::vector shapes = RestoreShapes(shape_concat, ranks); - auto* out = scope.FindVar(Output("Out")) - ->template GetMutable(); - out->Reset(new framework::RandomDataGenerator(shapes, Attr("min"), - Attr("max"))); - } -}; - -class CreateRandomDataGeneratorOpMaker - : public framework::OpProtoAndCheckerMaker { - public: - CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { - AddOutput("Out", "(ReaderHolder) The created random reader."); - AddAttr>("shape_concat", - "The concat of all data's shapes."); - AddAttr>( - "ranks", - "The ranks of each data." - "e.g." - "shape_concat = [2,3,4,5,6]" - "ranks = [3,2]" - "It means the reader will generate two data each time," - "whose shapes are [2,3,4] and [5,6] respectively."); - AddAttr>("lod_levels", "The LoD levels of each data."); - AddAttr("min", "The lower bound of reader's uniform distribution."); - AddAttr("max", "The upper bound of reader's uniform distribution."); - AddComment(R"DOC( - CreateRandomDataGenerator Operator - - This Op creates a random reader. - The reader generates random data instead of really reading from files. - Generated data follow an uniform distribution between 'min' and 'max'. - )DOC"); - } -}; - -class CreateShuffleReaderOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) - ->Get(); - auto* out = scope.FindVar(Output("Out")) - ->template GetMutable(); - out->Reset(new framework::ShuffleReader(underlying_reader.Get(), - Attr("buffer_size"))); - } -}; - -class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker { - public: - CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { - AddInput( - "UnderlyingReader", - "(ReaderHolder) The underlying reader for creating a shuffle reader."); - AddOutput("Out", "(ReaderHolder) The created shuffle reader."); - AddAttr("buffer_size", "The shuffle buffer size.").GreaterThan(0); - AddComment(R"DOC( - CreateShuffleReader Operator - - A shuffle reader takes another reader as its 'underlying reader' - and yields the underlying reader's outputs in a shuffled order. 
- )DOC"); - } -}; - -class CreateBatchReaderOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - - private: - void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) - ->Get(); - auto* out = scope.FindVar(Output("Out")) - ->template GetMutable(); - out->Reset(new framework::BatchReader(underlying_reader.Get(), - Attr("batch_size"))); - } -}; - -class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker { - public: - CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(op_proto, op_checker) { - AddInput( - "UnderlyingReader", - "(ReaderHolder) The underlying reader for creating a batch reader."); - AddOutput("Out", "(ReaderHolder) The created batch reader."); - AddAttr("batch_size", - "How many instances the batch reader yields each time.") - .GreaterThan(0); - AddComment(R"DOC( - CreateBatchReader Operator - - A batch reader takes another reader as its 'underlying reader', - gathers the underlying reader's outputs and then yields them in batches. - )DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(create_random_data_generator, - ops::CreateRandomDataGeneratorOp, - ops::CreateFileReaderInferShape, - ops::CreateRandomDataGeneratorOpMaker, - paddle::framework::EmptyGradOpMaker, - ops::CreateFileReaderInferVarType); -REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp, - ops::CreateDecoratedReaderInferShape, - ops::CreateShuffleReaderOpMaker, - paddle::framework::EmptyGradOpMaker, - ops::CreateDecoratedReaderInferVarType); -REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp, - ops::CreateDecoratedReaderInferShape, - ops::CreateBatchReaderOpMaker, - paddle::framework::EmptyGradOpMaker, - ops::CreateDecoratedReaderInferVarType); diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 571a75c9dcd..0581bd2ac55 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1 +1,3 @@ -grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) +if(WITH_DISTRIBUTE) + grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) +endif() diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt new file mode 100644 index 00000000000..16c93f3f8b1 --- /dev/null +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader) +op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry) +op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry) +op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry) diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc new file mode 100644 index 00000000000..bac043a5529 --- /dev/null +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -0,0 +1,137 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class BatchReader : public framework::DecoratedReader {
+ public:
+  BatchReader(ReaderBase* reader, int batch_size)
+      : DecoratedReader(reader), batch_size_(batch_size) {
+    buffer_.reserve(batch_size_);
+  }
+
+  void ReadNext(std::vector<framework::LoDTensor>* out) override;
+
+ private:
+  int batch_size_;
+  std::vector<std::vector<framework::LoDTensor>> buffer_;
+};
+
+class CreateBatchReaderOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
+                                        ->Get<framework::ReaderHolder>();
+    auto* out = scope.FindVar(Output("Out"))
+                    ->template GetMutable<framework::ReaderHolder>();
+    out->Reset(
+        new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
+  }
+};
+
+class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
+ public:
+  CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+      : DecoratedReaderMakerBase(op_proto, op_checker) {
+    AddAttr<int>("batch_size",
+                 "How many instances the batch reader yields each time.")
+        .GreaterThan(0);
+    AddComment(R"DOC(
+      CreateBatchReader Operator
+
+      A batch reader takes another reader as its 'underlying reader',
+      gathers the underlying reader's outputs and then yields them in batches.
+    )DOC");
+  }
+};
+
+void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
+  buffer_.clear();
+  buffer_.reserve(batch_size_);
+  for (int i = 0; i < batch_size_; ++i) {
+    if (reader_->HasNext()) {
+      buffer_.push_back(std::vector<framework::LoDTensor>());
+      reader_->ReadNext(&buffer_.back());
+    } else {
+      break;
+    }
+  }
+  // Concatenate instances
+  out->clear();
+  if (buffer_.empty()) {
+    // If buffer_ is empty, 'out' will be returned as an empty vector.
+    return;
+  }
+  int out_num = buffer_[0].size();
+  out->reserve(out_num);
+  for (int j = 0; j < out_num; ++j) {
+    // Merge shape and check data type
+    std::type_index batch_type = buffer_[0][j].type();
+    framework::DDim batch_shape = buffer_[0][j].dims();
+    for (size_t i = 1; i < buffer_.size(); ++i) {
+      std::type_index ins_type = buffer_[i][j].type();
+      framework::DDim ins_shape = buffer_[i][j].dims();
+      PADDLE_ENFORCE_EQ(batch_type, ins_type);
+      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
+                        slice_ddim(ins_shape, 1, ins_shape.size()));
+      PADDLE_ENFORCE_GT(ins_shape[0], 0);
+      batch_shape[0] += ins_shape[0];
+    }
+
+    framework::LoDTensor out_tensor;
+    out_tensor.Resize(batch_shape);
+    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
+    int64_t dst_offset = 0;
+
+    // Merge lod and data
+    framework::LoD batch_lod;
+    for (size_t i = 0; i < buffer_.size(); ++i) {
+      framework::DDim ins_shape = buffer_[i][j].dims();
+      framework::LoD ins_lod = buffer_[i][j].lod();
+      if (i == 0) {
+        batch_lod = ins_lod;
+      } else {
+        PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
+        for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
+          auto& lod_level = batch_lod[level_idx];
+          for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
+            lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
+          }
+        }
+      }
+      auto dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
+      TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
+      dst_offset += ins_shape[0];
+    }
+    out_tensor.set_lod(batch_lod);
+    out->push_back(out_tensor);
+  }
+}
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators::reader;
+REGISTER_DECORATED_READER_OPERATOR(create_batch_reader,
+                                   ops::CreateBatchReaderOp,
+                                   ops::CreateBatchReaderOpMaker);
diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc
new file mode 100644
index 00000000000..f77ab8ab196
--- /dev/null
+++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+template <typename T>
+class RandomDataGenerator : public framework::FileReader {
+ public:
+  RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
+                      float max)
+      : FileReader(shapes), min_(min), max_(max) {
+    PADDLE_ENFORCE_LE(
+        min, max, "'min' shouldn't be greater than 'max'. (%f vs %f)", min, max);
+    unsigned int seed = std::random_device()();
+    engine_.seed(seed);
+    dist_ = std::uniform_real_distribution<float>(min_, max_);
+  }
+
+  void ReadNext(std::vector<framework::LoDTensor>* out) override {
+    out->clear();
+    out->reserve(shapes_.size());
+    for (const framework::DDim& shape : shapes_) {
+      PADDLE_ENFORCE_GE(
+          shape.size(), 2,
+          "The rank of reader's output data should be at least 2. (Now it's %d)",
+          shape.size());
+      framework::LoDTensor out_tensor;
+      out_tensor.Resize(shape);
+      T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
+      int64_t numel = framework::product(shape);
+      for (int64_t i = 0; i < numel; ++i) {
+        data[i] = dist_(engine_);
+      }
+      out->push_back(out_tensor);
+    }
+  }
+
+  bool HasNext() const override { return true; }
+
+  void ReInit() override { return; }
+
+ private:
+  float min_;
+  float max_;
+  std::minstd_rand engine_;
+  std::uniform_real_distribution<float> dist_;
+};
+
+template <typename T>
+class CreateRandomDataGeneratorOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
+    const auto& ranks = Attr<std::vector<int>>("ranks");
+    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
+    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
+                      int(shape_concat.size()),
+                      "The sum of all ranks should be equal to the "
+                      "length of 'shape_concat'.");
+    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
+    auto* out = scope.FindVar(Output("Out"))
+                    ->template GetMutable<framework::ReaderHolder>();
+    out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("min"),
+                                          Attr<float>("max")));
+  }
+};
+
+class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
+ public:
+  CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+      : FileReaderMakerBase(op_proto, op_checker) {
+    AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
+    AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
+    AddComment(R"DOC(
+      CreateRandomDataGenerator Operator
+
+      This Op creates a random reader.
+      The reader generates random data instead of actually reading from files.
+      The generated data follow a uniform distribution between 'min' and 'max'.
+    )DOC");
+  }
+};
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators::reader;
+REGISTER_FILE_READER_OPERATOR(create_random_data_generator,
+                              ops::CreateRandomDataGeneratorOp<float>,
+                              ops::CreateRandomDataGeneratorOpMaker);
diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
new file mode 100644
index 00000000000..3e8b463efc9
--- /dev/null
+++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
@@ -0,0 +1,97 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/reader/reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+class ShuffleReader : public framework::DecoratedReader {
+ public:
+  ShuffleReader(ReaderBase* reader, int buffer_size)
+      : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
+    buffer_.reserve(buffer_size);
+  }
+
+  void ReadNext(std::vector<framework::LoDTensor>* out) override;
+
+ private:
+  int buffer_size_;
+  std::vector<std::vector<framework::LoDTensor>> buffer_;
+  size_t iteration_pos_;
+};
+
+void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
+  if (iteration_pos_ >= buffer_.size()) {
+    // Reload buffer with new data
+    buffer_.clear();
+    buffer_.reserve(buffer_size_);
+    for (int i = 0; i < buffer_size_; ++i) {
+      if (reader_->HasNext()) {
+        buffer_.push_back(std::vector<framework::LoDTensor>());
+        reader_->ReadNext(&buffer_.back());
+      } else {
+        break;
+      }
+    }
+    // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
+    // optimized.
+    std::random_shuffle(buffer_.begin(), buffer_.end());
+    iteration_pos_ = 0;
+  }
+  out->clear();
+  if (!buffer_.empty()) {
+    std::swap(*out, buffer_[iteration_pos_++]);
+  }
+  // If buffer_ is empty, 'out' will be returned as an empty vector.
+}
+
+class CreateShuffleReaderOp : public framework::OperatorBase {
+ public:
+  using framework::OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
+                                        ->Get<framework::ReaderHolder>();
+    auto* out = scope.FindVar(Output("Out"))
+                    ->template GetMutable<framework::ReaderHolder>();
+    out->Reset(
+        new ShuffleReader(underlying_reader.Get(), Attr<int>("buffer_size")));
+  }
+};
+
+class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
+ public:
+  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
+      : DecoratedReaderMakerBase(op_proto, op_checker) {
+    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
+    AddComment(R"DOC(
+      CreateShuffleReader Operator
+
+      A shuffle reader takes another reader as its 'underlying reader'
+      and yields the underlying reader's outputs in a shuffled order.
+    )DOC");
+  }
+};
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators::reader;
+REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
+                                   ops::CreateShuffleReaderOp,
+                                   ops::CreateShuffleReaderOpMaker);
diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc
new file mode 100644
index 00000000000..7ea4f4b8d9f
--- /dev/null
+++ b/paddle/fluid/operators/reader/reader_op_registry.cc
@@ -0,0 +1,116 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reader_op_registry.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
+                                           const std::vector<int>& ranks) {
+  std::vector<framework::DDim> res;
+  int offset = 0;
+  for (int len : ranks) {
+    auto start_it = shape_concat.begin() + offset;
+    auto end_it = start_it + len;
+    res.push_back(framework::make_ddim(std::vector<int64_t>(start_it, end_it)));
+    offset += len;
+  }
+  return res;
+}
+
+FileReaderMakerBase::FileReaderMakerBase(
+    framework::OpProtoAndCheckerMaker::OpProto* op_proto,
+    framework::OpAttrChecker* op_checker)
+    : OpProtoAndCheckerMaker(op_proto, op_checker) {
+  AddOutput("Out", "(ReaderHolder) The created reader.");
+  AddAttr<std::vector<int>>("shape_concat", "The concatenation of all data's shapes.");
+  AddAttr<std::vector<int>>(
+      "ranks",
+      "The ranks of each data. "
+      "e.g. "
+      "shape_concat = [2,3,4,5,6], "
+      "ranks = [3,2] "
+      "means the reader will generate two data items each time, "
+      "whose shapes are [2,3,4] and [5,6] respectively.");
+  AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
+}
+
+void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "The output file reader should not be null.");
+  const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
+  const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
+  std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
+  ctx->SetReaderDims("Out", shapes);
+
+  if (ctx->IsRuntime()) {
+    const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
+    PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
+                      "The number of 'lod_levels'(%d) doesn't match the number "
+                      "of 'shapes'(%d).",
+                      lod_levels.size(), shapes.size());
+    framework::VarDesc* reader =
+        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+    reader->SetLoDLevels(lod_levels);
+  }
+}
+
+void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
+                                        framework::BlockDesc* block) const {
+  std::string reader_name = op_desc.Output("Out")[0];
+  framework::VarDesc* reader = block->FindVarRecursive(reader_name);
+  reader->SetType(framework::proto::VarType::READER);
+}
+
+void DecoratedReaderInferShape::operator()(
+    framework::InferShapeContext* ctx) const {
+  PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
+                 "Input(UnderlyingReader) should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                 "The output decorated reader should not be null.");
+  ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
+
+  if (ctx->IsRuntime()) {
+    framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
+        ctx->GetInputVarPtrs("UnderlyingReader")[0]);
+    framework::VarDesc* out_reader =
+        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+    out_reader->SetLoDLevels(in_reader->GetLoDLevels());
+  }
+}
+void DecoratedReaderInferVarType::operator()(
+    const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
+  std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
+  framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
+  std::string out_reader_name = op_desc.Output("Out")[0];
+  framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
+  out_reader->SetType(framework::proto::VarType::READER);
+  out_reader->SetDataTypes(in_reader->GetDataTypes());
+}
+
+DecoratedReaderMakerBase::DecoratedReaderMakerBase(
+    framework::OpProtoAndCheckerMaker::OpProto* op_proto,
+    framework::OpAttrChecker* op_checker)
+    : OpProtoAndCheckerMaker(op_proto, op_checker) {
+  AddInput("UnderlyingReader",
+           "(ReaderHolder) The underlying reader for creating a decorated reader.");
+  AddOutput("Out", "(ReaderHolder) The created decorated reader.");
+}
+
+}  // namespace reader
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h
new file mode 100644
index 00000000000..d1f0498f469
--- /dev/null
+++ b/paddle/fluid/operators/reader/reader_op_registry.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/reader.h"
+
+namespace paddle {
+namespace operators {
+namespace reader {
+
+extern std::vector<framework::DDim> RestoreShapes(
+    const std::vector<int>& shape_concat, const std::vector<int>& ranks);
+
+class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker {
+ public:
+  FileReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
+};
+
+class FileReaderInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override;
+};
+
+class FileReaderInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override;
+};
+
+// general infershape for decorated readers
+class DecoratedReaderInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext* ctx) const override;
+};
+
+// general var type inference for decorated readers
+class DecoratedReaderInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override;
+};
+
+class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
+ public:
+  DecoratedReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
+};
+
+}  // namespace reader
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_FILE_READER_OPERATOR(op_name, ...)                   \
+  REGISTER_OPERATOR(op_name, __VA_ARGS__,                             \
+                    paddle::operators::reader::FileReaderInferShape,  \
+                    paddle::framework::EmptyGradOpMaker,              \
+                    paddle::operators::reader::FileReaderInferVarType)
+
+#define REGISTER_DECORATED_READER_OPERATOR(op_name, ...)                   \
+  REGISTER_OPERATOR(op_name, __VA_ARGS__,                                  \
+                    paddle::operators::reader::DecoratedReaderInferShape,  \
+                    paddle::framework::EmptyGradOpMaker,                   \
+                    paddle::operators::reader::DecoratedReaderInferVarType)
--
GitLab
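Editor's note (not part of the patch): the sketch below illustrates how a new decorated reader op could be added on top of the extracted reader_op_registry helpers. The op and file name (create_pass_through_reader) are hypothetical and only assume the classes and macros introduced in this patch; the structure mirrors create_shuffle_reader_op.cc. A matching op_library(create_pass_through_reader_op SRCS create_pass_through_reader_op.cc DEPS reader_op_registry) line in reader/CMakeLists.txt would also be needed.

// Hypothetical example file: paddle/fluid/operators/reader/create_pass_through_reader_op.cc
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

// A decorated reader that simply forwards its underlying reader's output.
class PassThroughReader : public framework::DecoratedReader {
 public:
  explicit PassThroughReader(ReaderBase* reader) : DecoratedReader(reader) {}

  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    // No transformation: hand the underlying reader's next batch straight out.
    reader_->ReadNext(out);
  }
};

class CreatePassThroughReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new PassThroughReader(underlying_reader.Get()));
  }
};

class CreatePassThroughReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreatePassThroughReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddComment(R"DOC(
      CreatePassThroughReader Operator (example only)

      Forwards the underlying reader's output unchanged.
    )DOC");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_pass_through_reader,
                                   ops::CreatePassThroughReaderOp,
                                   ops::CreatePassThroughReaderOpMaker);

Registering through REGISTER_DECORATED_READER_OPERATOR picks up DecoratedReaderInferShape and DecoratedReaderInferVarType automatically, which is the point of this extraction: a new reader op only supplies its reader class, RunImpl, and any extra attributes.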