From 175aa7ea956a96cb8f2215cf488b61adeee7a065 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Sat, 10 Feb 2018 12:16:53 +0800
Subject: [PATCH] add lod and dtype inference (#8329)

---
 paddle/framework/op_desc.cc                   |  7 +++
 paddle/framework/operator.cc                  |  4 ++
 paddle/framework/reader.cc                    |  6 ---
 paddle/framework/shape_inference.cc           | 22 +++++++++
 paddle/framework/shape_inference.h            | 10 +++++
 paddle/operators/create_reader_op.cc          | 45 ++++++++++++++++---
 .../paddle/v2/fluid/tests/test_cpp_reader.py  | 14 +++---
 7 files changed, 89 insertions(+), 19 deletions(-)

diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index b51afe499bb..90cc9b40236 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -77,6 +77,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   void SetRepeatedDims(const std::string &name,
                        const std::vector<DDim> &dims) override;
 
+  InferShapeVarPtr GetVarPtr(const std::string &name) override;
+
   const OpDesc &op_;
   const BlockDesc &block_;
 };
@@ -510,5 +512,10 @@ proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
   return block_.FindVarRecursive(name)->GetType();
 }
 
+InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
+    const std::string &name) {
+  return block_.FindVarRecursive(name);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 52387aabd9d..072dce8929f 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -470,6 +470,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
     return ToVarType(var->Type());
   }
 
+  InferShapeVarPtr GetVarPtr(const std::string& name) override {
+    return scope_.FindVar(name);
+  }
+
  private:
   const OperatorBase& op_;
   const Scope& scope_;
diff --git a/paddle/framework/reader.cc b/paddle/framework/reader.cc
index 928b661aaad..64caf85ed10 100644
--- a/paddle/framework/reader.cc
+++ b/paddle/framework/reader.cc
@@ -90,7 +90,6 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
 
       // Merge lod and data
       LoD batch_lod;
-      std::vector<size_t> top_level_lod({0});
       for (size_t i = 0; i < buffer_.size(); ++i) {
         DDim ins_shape = buffer_[i][j].dims();
         LoD ins_lod = buffer_[i][j].lod();
@@ -105,15 +104,10 @@ void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
             }
           }
         }
-        top_level_lod.push_back(
-            top_level_lod.back() +
-            (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));
-
         Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
         Copy(buffer_[i][j], platform::CPUPlace(), &dst);
         dst_offset += ins_shape[0];
       }
-      batch_lod.insert(batch_lod.begin(), top_level_lod);
       out_tensor.set_lod(batch_lod);
       out->push_back(out_tensor);
     }
diff --git a/paddle/framework/shape_inference.cc b/paddle/framework/shape_inference.cc
index 2f4d4505771..14fc635f07d 100644
--- a/paddle/framework/shape_inference.cc
+++ b/paddle/framework/shape_inference.cc
@@ -72,6 +72,28 @@ void InferShapeContext::SetReaderDims(const std::string &name,
   return this->SetRepeatedDims(arg_names[0], dims);
 }
 
+std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
+    const std::string &name) {
+  const std::vector<std::string> arg_names = Inputs(name);
+  std::vector<InferShapeVarPtr> res;
+  res.reserve(arg_names.size());
+  std::transform(
+      arg_names.begin(), arg_names.end(), std::back_inserter(res),
+      [this](const std::string &name) { return this->GetVarPtr(name); });
+  return res;
+}
+
+std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
+    const std::string &name) {
+  const std::vector<std::string> arg_names = Outputs(name);
+  std::vector<InferShapeVarPtr> res;
+  res.reserve(arg_names.size());
+  std::transform(
+      arg_names.begin(), arg_names.end(), std::back_inserter(res),
+      [this](const std::string &name) { return this->GetVarPtr(name); });
+  return res;
+}
+
 std::vector<DDim> InferShapeContext::GetDims(
     const std::vector<std::string> &names) const {
   std::vector<DDim> ret;
diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h
index 7bee8698523..3d4e8298bf5 100644
--- a/paddle/framework/shape_inference.h
+++ b/paddle/framework/shape_inference.h
@@ -17,10 +17,14 @@ limitations under the License. */
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/framework.pb.h"
+#include "paddle/framework/var_desc.h"
+#include "paddle/framework/variable.h"
 
 namespace paddle {
 namespace framework {
 
+using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
+
 class InferShapeContext {
  public:
   virtual ~InferShapeContext() = default;
@@ -55,6 +59,9 @@ class InferShapeContext {
 
   virtual bool IsRuntime() const = 0;
 
+  std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
+  std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
+
   // Note: In while op, we need this to be public
   void SetDims(const std::vector<std::string> &names,
                const std::vector<DDim> &dims);
@@ -67,10 +74,13 @@ class InferShapeContext {
                               const std::vector<DDim> &dims) = 0;
 
   std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
+
   std::vector<proto::VarDesc::VarType> GetVarTypes(
       const std::vector<std::string> &names) const;
 
   virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
+
+  virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
 };
diff --git a/paddle/operators/create_reader_op.cc b/paddle/operators/create_reader_op.cc
index 5ba2a25ab4c..71f5202d7e6 100644
--- a/paddle/operators/create_reader_op.cc
+++ b/paddle/operators/create_reader_op.cc
@@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
     const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
     std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
     ctx->SetReaderDims("Out", shapes);
+
+    if (ctx->IsRuntime()) {
+      const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
+      PADDLE_ENFORCE_EQ(
+          lod_levels.size(), shapes.size(),
+          "The number of 'lod_levels'(%d) doesn't match the number "
+          "of 'shapes'(%d).",
+          lod_levels.size(), shapes.size());
+      framework::VarDesc* reader =
+          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+      reader->SetLoDLevels(lod_levels);
+    }
   }
 };
@@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "The output decorated reader should not be null.");
     ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
+
+    if (ctx->IsRuntime()) {
+      framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
+          ctx->GetInputVarPtrs("UnderlyingReader")[0]);
+      framework::VarDesc* out_reader =
+          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
+      out_reader->SetLoDLevels(in_reader->GetLoDLevels());
+    }
   }
 };
 
-// general var type inference for all readers
-class CreateReaderInferVarType : public framework::VarTypeInference {
+// general var type inference for file readers
+class CreateFileReaderInferVarType : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc& op_desc,
                   framework::BlockDesc* block) const override {
@@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
   }
 };
 
+// general var type inference for decorated readers
+class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
+    framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
+    std::string out_reader_name = op_desc.Output("Out")[0];
+    framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
+    out_reader->SetType(framework::proto::VarDesc::READER);
+    out_reader->SetDataTypes(in_reader->GetDataTypes());
+  }
+};
+
 template <typename T>
 class CreateRandomDataGeneratorOp : public framework::OperatorBase {
  public:
@@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker
              "ranks = [3,2]"
              "It means the reader will generate two data each time,"
              "whose shapes are [2,3,4] and [5,6] respectively.");
+    AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
     AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
     AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
     AddComment(R"DOC(
@@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator,
                   ops::CreateFileReaderInferShape,
                   ops::CreateRandomDataGeneratorOpMaker,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::CreateReaderInferVarType);
+                  ops::CreateFileReaderInferVarType);
 REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
                   ops::CreateDecoratedReaderInferShape,
                   ops::CreateShuffleReaderOpMaker,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::CreateReaderInferVarType);
+                  ops::CreateDecoratedReaderInferVarType);
 REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
                   ops::CreateDecoratedReaderInferShape,
                   ops::CreateBatchReaderOpMaker,
                   paddle::framework::EmptyGradOpMaker,
-                  ops::CreateReaderInferVarType);
+                  ops::CreateDecoratedReaderInferVarType);
diff --git a/python/paddle/v2/fluid/tests/test_cpp_reader.py b/python/paddle/v2/fluid/tests/test_cpp_reader.py
index 970f57ed000..66d6c28ef7d 100644
--- a/python/paddle/v2/fluid/tests/test_cpp_reader.py
+++ b/python/paddle/v2/fluid/tests/test_cpp_reader.py
@@ -21,7 +21,8 @@ block = prog.current_block()
 
 random_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
-random_reader.desc.set_lod_levels([0, 0])
+random_reader.desc.set_dtypes(
+    [fluid.core.DataType.FP32, fluid.core.DataType.FP32])
 
 create_random_data_generator_op = block.append_op(
     type="create_random_data_generator",
@@ -30,11 +31,11 @@ create_random_data_generator_op = block.append_op(
         "shape_concat": [1, 2, 1, 1],
         "ranks": [2, 2],
         "min": 0.0,
-        "max": 1.0
+        "max": 1.0,
+        'lod_levels': [0, 0]
     })
 shuffle_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="ShuffleReader")
-shuffle_reader.desc.set_lod_levels([0, 0])
 
 create_shuffle_reader_op = block.append_op(
     type="create_shuffle_reader",
@@ -44,7 +45,6 @@ create_shuffle_reader_op = block.append_op(
 
 batch_reader = block.create_var(
     type=fluid.core.VarDesc.VarType.READER, name="BatchReader")
-batch_reader.desc.set_lod_levels([1, 1])
 
 create_batch_reader_op = block.append_op(
     type="create_batch_reader",
@@ -62,11 +62,9 @@ read_op = block.append_op(
 place = fluid.CPUPlace()
 exe = fluid.Executor(place)
 
-[res1, res2] = exe.run(prog, fetch_list=[out1, out2], return_numpy=False)
+[res1, res2] = exe.run(prog, fetch_list=[out1, out2])
 
-test_pass = res1.lod() == [range(0, 11)] and res1.lod() == [
-    range(0, 11)
-] and np.array(res1).shape == (10, 2) and np.array(res2).shape == (10, 1)
+test_pass = res1.shape == (10, 2) and res2.shape == (10, 1)
 
 if not test_pass:
     exit(1)
-- 
GitLab
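
Note on the new API introduced by this patch: GetInputVarPtrs/GetOutputVarPtrs return InferShapeVarPtr, a boost::variant<VarDesc *, Variable *>. CompileTimeInferShapeContext::GetVarPtr yields a VarDesc * (via block_.FindVarRecursive), while RuntimeInferShapeContext::GetVarPtr yields a Variable * (via scope_.FindVar), so a caller must boost::get the alternative matching its phase; requesting the wrong one throws boost::bad_get. Below is a minimal sketch of a consumer in the style of CreateDecoratedReaderInferShape above. It is illustrative only: the class name PassThroughReaderInferShape, the argument names "X"/"Out", and the include set are assumptions, not part of this patch.

    #include "paddle/framework/op_registry.h"      // assumed to provide InferShapeBase
    #include "paddle/framework/shape_inference.h"  // InferShapeContext, InferShapeVarPtr

    namespace paddle {
    namespace operators {

    // Hypothetical shape-inference functor that forwards reader metadata from
    // its input reader to its output reader, mirroring the pattern used by
    // CreateDecoratedReaderInferShape in this patch.
    class PassThroughReaderInferShape : public framework::InferShapeBase {
     public:
      void operator()(framework::InferShapeContext* ctx) const override {
        // GetInputVarPtrs/GetOutputVarPtrs return one InferShapeVarPtr per
        // variable bound to the given argument name.
        framework::InferShapeVarPtr in_ptr = ctx->GetInputVarPtrs("X")[0];
        framework::InferShapeVarPtr out_ptr = ctx->GetOutputVarPtrs("Out")[0];
        // Pick the VarDesc* alternative; this is only valid when the variant
        // was produced by CompileTimeInferShapeContext::GetVarPtr.
        auto* in_desc = boost::get<framework::VarDesc*>(in_ptr);
        auto* out_desc = boost::get<framework::VarDesc*>(out_ptr);
        // Copy the LoD levels declared on the input reader to the output reader.
        out_desc->SetLoDLevels(in_desc->GetLoDLevels());
      }
    };

    }  // namespace operators
    }  // namespace paddle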