From 1010e39bdf738029fcb78b0d388a91dfdebdda2f Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 6 Feb 2018 12:39:51 +0800 Subject: [PATCH] Add ReadOp --- paddle/framework/framework.proto | 4 +- paddle/framework/op_desc.cc | 29 +++++++-- paddle/framework/operator.cc | 26 ++++++-- paddle/framework/reader.cc | 40 ++++++------ paddle/framework/reader.h | 32 +++++----- paddle/framework/shape_inference.cc | 14 +++++ paddle/framework/shape_inference.h | 3 +- paddle/operators/read_op.cc | 94 +++++++++++++++++++++++++++++ 8 files changed, 193 insertions(+), 49 deletions(-) create mode 100644 paddle/operators/read_op.cc diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index f65ccae6e6a..d7be1a7352d 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -116,7 +116,7 @@ message LoDTensorArrayDesc { optional int32 lod_level = 2 [ default = 0 ]; } -message Reader { repeated LoDTensorDesc lod_tensor = 1; } +message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } message VarDesc { enum VarType { @@ -136,7 +136,7 @@ message VarDesc { optional LoDTensorDesc lod_tensor = 4; optional TensorDesc selected_rows = 5; optional LoDTensorArrayDesc tensor_array = 6; - optional Reader reader = 7; + optional ReaderDesc reader = 7; } message BlockDesc { diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index ad361852ec9..772ec26895e 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -72,6 +72,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { void SetDim(const std::string &name, const DDim &dim) override; + std::vector GetRepeatedDim(const std::string &name) const override; + const OpDesc &op_; const BlockDesc &block_; }; @@ -457,22 +459,37 @@ const std::vector &CompileTimeInferShapeContext::Outputs( DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + DDim res; try { auto shape = var->GetShape(); - if (shape.empty()) { - return framework::make_ddim({0UL}); - } else { - return framework::make_ddim(var->GetShape()); - } + res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape); } catch (...) { VLOG(5) << "GetDim of variable " << name << " error"; std::rethrow_exception(std::current_exception()); } + return res; +} + +std::vector CompileTimeInferShapeContext::GetRepeatedDim( + const std::string &name) const { + auto var = block_.FindVarRecursive(name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); + std::vector res; + try { + auto shapes = var->GetShapes(); + for (const auto &s : shapes) { + res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s)); + } + } catch (...) { + VLOG(5) << "GetRepeatedDim of variable " << name << " error."; + std::rethrow_exception(std::current_exception()); + } + return res; } void CompileTimeInferShapeContext::SetDim(const std::string &name, const DDim &dim) { - block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); + block_.FindVarRecursive(name)->SetShape(vectorize(dim)); } bool CompileTimeInferShapeContext::IsRuntime() const { return false; } diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 81fa8cf4774..1aa111dc76d 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext { if (length == 0) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs", - name); + PADDLE_ENFORCE_EQ(length, 1UL, + "Input %s should not have more than one inputs", name); auto ipt = ins[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; @@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext { if (length == 0) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs", - name); + PADDLE_ENFORCE_EQ(length, 1UL, + "Output %s should not have more than one inputs", name); auto ipt = outs[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; @@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext { } else if (var->IsType()) { return var->Get().GetCompleteDims(); } else { - PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", - name, var->Type().name()); + PADDLE_THROW( + "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " + "type_id is %s.", + name, var->Type().name()); + } + } + + std::vector GetRepeatedDim(const std::string& name) const override { + Variable* var = scope_.FindVar(name); + if (var->IsType()) { + return var->Get().shapes(); + } else { + PADDLE_THROW( + "Only ReaderHolder support 'GetRepeatedDim', but Variable %s's " + "type_id is %s.", + name, var->Type().name()); } } diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc index a05bef42ffa..76cbc827ba5 100644 --- a/paddle/framework/reader.cc +++ b/paddle/framework/reader.cc @@ -25,13 +25,15 @@ DDim FileReader::shape(size_t idx) const { return shapes_[idx]; } -std::vector ShuffleReader::ReadNext() { +void ShuffleReader::ReadNext(std::vector* out) { if (iteration_pos_ >= buffer_.size()) { // Reload buffer with new data buffer_.clear(); + buffer_.reverse(buffer_size_); for (int i = 0; i < buffer_size_; ++i) { if (reader_->HasNext()) { - buffer_.push_back(reader_->ReadNext()); + buffer.push_back(std::vector()); + reader_->ReadNext(&buffer.back()); } else { break; } @@ -39,29 +41,32 @@ std::vector ShuffleReader::ReadNext() { std::random_shuffle(buffer_.begin(), buffer_.end()); iteration_pos_ = 0; } - if (buffer_.empty()) { - std::vector empty_res; - return empty_res; + out->clear(); + if (!buffer_.empty()) { + std::swap(*out, buffer_[iteration_pos_++]); } - return buffer_[iteration_pos_++]; + // if buffer_ is empty, the 'out' will return as an empty vector. } -std::vector BatchReader::ReadNext() { +void BatchReader::ReadNext(std::vector* out) { buffer_.clear(); + buffer_.reserve(batch_size_); for (int i = 0; i < batch_size_; ++i) { if (reader_->HasNext()) { - buffer_.push_back(reader_->ReadNext()); + buffer_.push_back(std::vector()); + reader_->ReadNext(&buffer_.back()); } else { break; } } // Concat instances - std::vector res; + out.clear(); if (buffer_.empty()) { - return res; + // if buffer_ is empty, the 'out' will return as an empty vector. + return; } int out_num = buffer_[0].size(); - res.reserve(out_num); + out->reserve(out_num); for (int j = 0; j < out_num; ++j) { // Merge shape and check date type std::type_index batch_type = buffer_[0][j].type(); @@ -76,9 +81,9 @@ std::vector BatchReader::ReadNext() { batch_shape[0] += ins_shape[0]; } - LoDTensor out; - out.Resize(batch_shape); - out.mutable_data(platform::CPUPlace(), batch_type); + LoDTensor out_tensor; + out_tensor.Resize(batch_shape); + out_tensor.mutable_data(platform::CPUPlace(), batch_type); int64_t dst_offset = 0; // Merge lod and data @@ -102,15 +107,14 @@ std::vector BatchReader::ReadNext() { top_level_lod.back() + (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1))); - Tensor dst = out.Slice(dst_offset, dst_offset + ins_shape[0]); + Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]); Copy(buffer_[i][j], platform::CPUPlace(), &dst); dst_offset += ins_shape[0]; } batch_lod.insert(batch_lod.begin(), top_level_lod); - out.set_lod(batch_lod); - res.push_back(out); + out_tensor.set_lod(batch_lod); + out->push_back(out_tensor); } - return res; } } // namespace framework } // namespace paddle diff --git a/paddle/framework/reader.h b/paddle/framework/reader.h index f450e67689a..523ff28c990 100644 --- a/paddle/framework/reader.h +++ b/paddle/framework/reader.h @@ -15,14 +15,14 @@ #pragma once #include "paddle/framework/ddim.h" -#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/lod_tensor_array.h" namespace paddle { namespace framework { class ReaderBase { public: - virtual std::vector ReadNext() = 0; + virtual void ReadNext(std::vector* out) = 0; virtual bool HasNext() const = 0; virtual DDim shape(size_t idx) const = 0; @@ -73,24 +73,24 @@ class RandomReader : public FileReader { dist_ = std::uniform_real_distribution(min_, max_); } - std::vector ReadNext() override { - std::vector res; - res.reserve(shapes_.size()); + void ReadNext(std::vector* out) override { + out.clear(); + out.reserve(shapes_.size()); for (const DDim& shape : shapes_) { PADDLE_ENFORCE_GE( shape.size(), 2, - "The rank of input data should be 2 at least.(Now it's %d)", + "The rank of reader's output data should be 2 at least.(Now it's %d)", shape.size()); - LoDTensor out; - out.Resize(shape); - T* data = out.mutable_data(platform::CPUPlace()); + LoDTensor out_tensor; + out_tensor.Resize(shape); + T* data = out_tensor.mutable_data(platform::CPUPlace()); int64_t numel = product(shape); for (int64_t i = 0; i < numel; ++i) { data[i] = dist_(engine_); } - res.push_back(out); + out.push_back(out_tensor); } - return res; + return out; } bool HasNext() const override { return true; } @@ -111,11 +111,11 @@ class ShuffleReader : public DecoratedReader { buffer_.reserve(buffer_size); } - std::vector ReadNext() override; + void ReadNext(std::vector* out) override; private: int buffer_size_; - std::vector> buffer_; + std::vector> buffer_; size_t iteration_pos_; }; @@ -126,11 +126,11 @@ class BatchReader : public DecoratedReader { buffer_.reserve(batch_size_); } - std::vector ReadNext() override; + void ReadNext(std::vector* out) override; private: int batch_size_; - std::vector> buffer_; + std::vector> buffer_; }; // The ReaderHolder is used as readers' unified wrapper, @@ -141,7 +141,7 @@ class ReaderHolder { ReaderBase* Get() const { return reader_.get(); } - std::vector ReadNext() { return reader_->ReadNext(); } + void ReadNext(std::vector* out) { reader_->ReadNext(out); } bool HasNext() const { return reader_->HasNext(); } DDim shape(size_t idx) const { return reader_->shape(idx); } diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc index a0fa467291b..4a8acfb87ff 100644 --- a/paddle/framework/shape_inference.cc +++ b/paddle/framework/shape_inference.cc @@ -32,6 +32,16 @@ std::vector InferShapeContext::GetInputsDim( return GetDims(arg_names); } +std::vector InferShapeContext::GetReaderDims( + const std::string &name) const { + const std::vector &arg_names = Inputs(name); + PADDLE_ENFORCE_EQ( + arg_names.size(), 1UL, + "Reader input '%s' should hold one element, but now it holds %d", name, + arg_names.size()); + return this->GetRepeatedDims(arg_names[0]); +} + DDim InferShapeContext::GetInputsElementDim(const std::string &name, int idx) const { const std::vector &names = Inputs(name); @@ -61,6 +71,7 @@ std::vector InferShapeContext::GetDims( [this](const std::string &name) { return this->GetDim(name); }); return ret; } + void InferShapeContext::SetDims(const std::vector &names, const std::vector &dims) { size_t length = names.size(); @@ -72,14 +83,17 @@ void InferShapeContext::SetDims(const std::vector &names, SetDim(names[i], dims[i]); } } + std::vector InferShapeContext::GetInputsVarType( const std::string &name) const { return GetVarTypes(Inputs(name)); } + std::vector InferShapeContext::GetOutputsVarType( const std::string &name) const { return GetVarTypes(Outputs(name)); } + std::vector InferShapeContext::GetVarTypes( const std::vector &names) const { std::vector retv; diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 830f199ed14..f1a64e9024b 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -36,8 +36,8 @@ class InferShapeContext { virtual bool HasOutputs(const std::string &name) const = 0; DDim GetInputDim(const std::string &name) const; - std::vector GetInputsDim(const std::string &name) const; + std::vector GetReaderDims(const std::string &name) const DDim; DDim GetInputsElementDim(const std::string &name, int idx) const; void SetOutputDim(const std::string &name, const DDim &dim); @@ -61,6 +61,7 @@ class InferShapeContext { protected: virtual DDim GetDim(const std::string &name) const = 0; virtual void SetDim(const std::string &name, const DDim &dim) = 0; + std::vector GetRepeatedDim(const std::string &name) const = 0; std::vector GetDims(const std::vector &names) const; std::vector GetVarTypes( diff --git a/paddle/operators/read_op.cc b/paddle/operators/read_op.cc new file mode 100644 index 00000000000..c6ff4ba8fee --- /dev/null +++ b/paddle/operators/read_op.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/reader.h" + +namespace paddle { +namespace operators { + +class ReadInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Reader"), + "The ReadOp must take a reader as input."); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), + "The ReadOp should be assigned with output."); + std::vector reader_dims = ctx->GetReaderDims("Reader"); + std::vector out_names = ctx->Outputs("Out"); + PADDLE_ENFORCE_EQ( + reader_dims.size(), out_names.size(), + "The reader's dim number doesn't match the output number."); + ctx->SetOutputsDim("Out", reader_dims); + } +}; + +class ReadInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + std::string reader_name = op_desc.Input("Reader")[0]; + std::vector out_names = op_desc.Output("Out"); + framework::VarDesc reader = block.FindVarRecursive(reader_name); + auto dtypes = reader.GetDataTypes(); + PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size()); + for (size_t i = 0; i < dtypes.size(); ++i) { + faremwork::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]); + out.SetType(framework::proto::DataType::LOD_TENSOR); + out.SetDataType(dtypes[i]); + } + } +}; + +class ReadOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + void Run(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const framework::ReaderHolder& reader = + scope.FindVar(Input("Reader"))->Get(); + if (!reader.HasNext()) { + // what shall we do??? + return; + } + std::vector out_arg_names = Outputs("Out"); + std::vector ins; + reader.ReadNext(&ins); + PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); + for (size_t i = 0; i < ins.size(); ++i) { + auto* out = + scope.FindVar(out_arg_names[i])->GetMutable(); + PADDLE_ENFORCE_EQ(ins[i].dims(), out->dims()); + out->ShareDataWith(ins[i]); + out->set_lod(ins[i].lod()); + } + } +}; + +class ReadOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(op_proto, op_checker) { + AddInput("Reader", "(ReaderHolder) The executed reader."); + AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); + AddComment(R"DOC( + Read Operator + + Execute a given reader once and output data. + )DOC") + } +}; + +} // namespace operators +} // namespace paddle \ No newline at end of file -- GitLab