From 7dabee27960b5e043b85aca3ee51568443b326f4 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 5 Feb 2018 15:00:03 +0800 Subject: [PATCH] Add type Reader for VarDesc Add a new type `Reader` for `VarDesc`, which can holds more than one LoDTensor. --- paddle/framework/backward.cc | 4 +- paddle/framework/framework.proto | 10 +- paddle/framework/op_desc.cc | 4 +- paddle/framework/program_desc_test.cc | 4 +- paddle/framework/var_desc.cc | 174 ++++++++++++++++-- paddle/framework/var_desc.h | 20 +- paddle/inference/io.cc | 2 +- paddle/pybind/protobuf.cc | 14 +- .../v2/fluid/tests/test_protobuf_descs.py | 38 ++++ 9 files changed, 246 insertions(+), 24 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 85e693434af..f52a51519fc 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward( auto root_block = program_desc.MutableBlock(root_block_idx); std::string fill_one_op_out = GradVarName(target.Name()); - bool is_scalar = target.Shape() == std::vector{1}; + bool is_scalar = target.GetShape() == std::vector{1}; PADDLE_ENFORCE(is_scalar, "target should be scalar"); VLOG(3) << "backward from loss=" << target.Name() << " data_type=" << target.GetDataType(); @@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward( auto var = root_block->Var(fill_one_op_out); var->SetDataType(target.GetDataType()); - var->SetShape(target.Shape()); + var->SetShape(target.GetShape()); auto& target_grad = retv[target.Name()]; target_grad.name_ = fill_one_op_out; target_grad.block_idx_ = root_block_idx; diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 5b6ef03f610..f65ccae6e6a 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -116,6 +116,8 @@ message LoDTensorArrayDesc { optional int32 lod_level = 2 [ default = 0 ]; } +message Reader { repeated LoDTensorDesc lod_tensor = 1; } + message VarDesc { enum VarType { LOD_TENSOR = 1; @@ -126,13 +128,15 @@ message VarDesc { LOD_RANK_TABLE = 6; LOD_TENSOR_ARRAY = 7; PLACE_LIST = 8; + READER = 9; } required string name = 1; required VarType type = 2; - optional LoDTensorDesc lod_tensor = 3; - optional TensorDesc selected_rows = 4; + optional bool persistable = 3 [ default = false ]; + optional LoDTensorDesc lod_tensor = 4; + optional TensorDesc selected_rows = 5; optional LoDTensorArrayDesc tensor_array = 6; - optional bool persistable = 5 [ default = false ]; + optional Reader reader = 7; } message BlockDesc { diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index f554c778450..ad361852ec9 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const { auto var = block_.FindVarRecursive(name); PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name); try { - auto shape = var->Shape(); + auto shape = var->GetShape(); if (shape.empty()) { return framework::make_ddim({0UL}); } else { - return framework::make_ddim(var->Shape()); + return framework::make_ddim(var->GetShape()); } } catch (...) { VLOG(5) << "GetDim of variable " << name << " error"; diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc index 59947c9f218..9945aee31b6 100644 --- a/paddle/framework/program_desc_test.cc +++ b/paddle/framework/program_desc_test.cc @@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_NE(copy, var_before); ASSERT_EQ(copy->Name(), var_before->Name()); ASSERT_EQ(copy->GetType(), var_before->GetType()); - ASSERT_EQ(copy->Shape(), var_before->Shape()); + ASSERT_EQ(copy->GetShape(), var_before->GetShape()); ASSERT_EQ(copy->Proto()->SerializeAsString(), var_before->Proto()->SerializeAsString()); }; @@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ASSERT_NE(restored, var_before); ASSERT_EQ(restored->Name(), var_before->Name()); ASSERT_EQ(restored->GetType(), var_before->GetType()); - ASSERT_EQ(restored->Shape(), var_before->Shape()); + ASSERT_EQ(restored->GetShape(), var_before->GetShape()); ASSERT_EQ(restored->Proto()->SerializeAsString(), var_before->Proto()->SerializeAsString()); }; diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index 62ab6593ef2..44bd2363c83 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector &dims) { VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); } +void VarDesc::SetTensorDescNum(size_t num) { + switch (desc_.type()) { + case proto::VarDesc::READER: { + auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor(); + lod_tensors_ptr->Clear(); + for (size_t i = 0; i < num; ++i) { + lod_tensors_ptr->Add(); + } + return; + } break; + default: + PADDLE_THROW( + "Setting 'sub_tensor_number' is not supported by the type of var %s.", + this->Name()); + } +} + +size_t VarDesc::GetTensorDescNum() const { + switch (desc_.type()) { + case proto::VarDesc::READER: + return desc_.reader().lod_tensor_size(); + break; + default: + PADDLE_THROW( + "Getting 'sub_tensor_number' is not supported by the type of var %s.", + this->Name()); + } +} + +void VarDesc::SetShapes( + const std::vector> &multiple_dims) { + PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(), + "The number of given shapes(%d) doesn't equal to the " + "number of sub tensor.", + multiple_dims.size(), GetTensorDescNum()); + std::vector tensors = mutable_tensor_descs(); + for (size_t i = 0; i < multiple_dims.size(); ++i) { + VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); + } +} + +std::vector VarDesc::GetShape() const { + return RepeatedToVector(tensor_desc().dims()); +} + +std::vector> VarDesc::GetShapes() const { + std::vector descs = tensor_descs(); + std::vector> res; + res.reserve(descs.size()); + for (const auto &tensor_desc : descs) { + res.push_back(RepeatedToVector(tensor_desc.dims())); + } + return res; +} + void VarDesc::SetDataType(proto::DataType data_type) { mutable_tensor_desc()->set_data_type(data_type); } -std::vector VarDesc::Shape() const { - return RepeatedToVector(tensor_desc().dims()); +void VarDesc::SetDataTypes( + const std::vector &multiple_data_type) { + PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(), + "The number of given data types(%d) doesn't equal to the " + "number of sub tensor.", + multiple_data_type.size(), GetTensorDescNum()); + std::vector tensor_descs = mutable_tensor_descs(); + for (size_t i = 0; i < multiple_data_type.size(); ++i) { + tensor_descs[i]->set_data_type(multiple_data_type[i]); + } } proto::DataType VarDesc::GetDataType() const { return tensor_desc().data_type(); } +std::vector VarDesc::GetDataTypes() const { + std::vector descs = tensor_descs(); + std::vector res; + res.reserve(descs.size()); + for (const auto &tensor_desc : descs) { + res.push_back(tensor_desc.data_type()); + } + return res; +} + void VarDesc::SetLoDLevel(int32_t lod_level) { switch (desc_.type()) { case proto::VarDesc::LOD_TENSOR: @@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) { desc_.mutable_tensor_array()->set_lod_level(lod_level); break; default: - PADDLE_THROW("Tensor type=%d does not support LoDLevel", - desc_.tensor_array().lod_level()); + PADDLE_THROW( + "Setting 'lod_level' is not supported by the type of var %s.", + this->Name()); + } +} + +void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { + PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(), + "The number of given data types(%d) doesn't equal to the " + "number of sub tensor.", + multiple_lod_level.size(), GetTensorDescNum()); + switch (desc_.type()) { + case proto::VarDesc::READER: { + size_t i = 0; + for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { + lod_tensor.set_lod_level(multiple_lod_level[i++]); + } + } break; + default: + PADDLE_THROW( + "Setting 'lod_levels' is not supported by the type of var %s.", + this->Name()); } } @@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.tensor_array().lod_level(); default: - PADDLE_THROW("Tensor type=%d does not support LoDLevel", - desc_.tensor_array().lod_level()); + PADDLE_THROW( + "Getting 'lod_level' is not supported by the type of var %s.", + this->Name()); + } +} + +std::vector VarDesc::GetLoDLevels() const { + std::vector res; + switch (desc_.type()) { + case proto::VarDesc::READER: + res.reserve(desc_.reader().lod_tensor_size()); + for (auto &lod_tensor : desc_.reader().lod_tensor()) { + res.push_back(lod_tensor.lod_level()); + } + return res; + break; + default: + PADDLE_THROW( + "Getting 'lod_levels' is not supported by the type of var %s.", + this->Name()); } } const proto::TensorDesc &VarDesc::tensor_desc() const { - PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type"); + PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set."); switch (desc_.type()) { case proto::VarDesc::SELECTED_ROWS: return desc_.selected_rows(); @@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.tensor_array().tensor(); default: - PADDLE_THROW("The type of var %s is unsupported.", this->Name()); + PADDLE_THROW( + "Getting 'tensor_desc' is not supported by the type of var %s.", + this->Name()); + } +} + +std::vector VarDesc::tensor_descs() const { + PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); + std::vector res; + res.reserve(GetTensorDescNum()); + switch (desc_.type()) { + case proto::VarDesc::READER: + for (const auto &lod_tensor : desc_.reader().lod_tensor()) { + res.push_back(lod_tensor.tensor()); + } + return res; + default: + PADDLE_THROW( + "Getting 'tensor_descs' is not supported by the type of var " + "%s.", + this->Name()); } } proto::TensorDesc *VarDesc::mutable_tensor_desc() { - PADDLE_ENFORCE(desc_.has_type(), - "invoke MutableTensorDesc must after set type"); + PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); switch (desc_.type()) { case proto::VarDesc::SELECTED_ROWS: return desc_.mutable_selected_rows(); @@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() { case proto::VarDesc::LOD_TENSOR_ARRAY: return desc_.mutable_tensor_array()->mutable_tensor(); default: - PADDLE_THROW("Unexpected branch."); + PADDLE_THROW( + "Getting 'mutable_tensor_desc' is not supported by the type of var " + "%s.", + this->Name()); } } + +std::vector VarDesc::mutable_tensor_descs() { + PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set."); + std::vector res; + res.reserve(GetTensorDescNum()); + switch (desc_.type()) { + case proto::VarDesc::READER: + for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) { + res.push_back(lod_tensor.mutable_tensor()); + } + return res; + default: + PADDLE_THROW( + "Getting 'tensor_descs' is not supported by the type of var " + "%s.", + this->Name()); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 9316b14bb69..862b9a5d807 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -68,18 +68,34 @@ class VarDesc { void SetName(std::string name) { desc_.set_name(name); } + void SetTensorDescNum(size_t num); + + size_t GetTensorDescNum() const; + void SetShape(const std::vector &dims); + void SetShapes(const std::vector> &multiple_dims); + + std::vector GetShape() const; + + std::vector> GetShapes() const; + void SetDataType(proto::DataType data_type); - std::vector Shape() const; + void SetDataTypes(const std::vector &multiple_data_type); proto::DataType GetDataType() const; + std::vector GetDataTypes() const; + void SetLoDLevel(int32_t lod_level); + void SetLoDLevels(const std::vector &multiple_lod_level); + int32_t GetLoDLevel() const; + std::vector GetLoDLevels() const; + proto::VarDesc::VarType GetType() const; void SetType(proto::VarDesc::VarType type); @@ -90,7 +106,9 @@ class VarDesc { private: const proto::TensorDesc &tensor_desc() const; + std::vector tensor_descs() const; proto::TensorDesc *mutable_tensor_desc(); + std::vector mutable_tensor_descs(); proto::VarDesc desc_; }; diff --git a/paddle/inference/io.cc b/paddle/inference/io.cc index 60ad7af1c0a..1ed14b69c83 100644 --- a/paddle/inference/io.cc +++ b/paddle/inference/io.cc @@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor, VLOG(3) << "parameter's name: " << var->Name(); framework::VarDesc* new_var = load_block->Var(var->Name()); - new_var->SetShape(var->Shape()); + new_var->SetShape(var->GetShape()); new_var->SetDataType(var->GetDataType()); new_var->SetType(var->GetType()); new_var->SetLoDLevel(var->GetLoDLevel()); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 371d6119d4a..0f1953abe08 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) { py::return_value_policy::reference) .def("set_name", &VarDesc::SetName) .def("set_shape", &VarDesc::SetShape) + .def("set_shapes", &VarDesc::SetShapes) .def("set_dtype", &VarDesc::SetDataType) - .def("shape", &VarDesc::Shape, py::return_value_policy::reference) + .def("set_dtypes", &VarDesc::SetDataTypes) + .def("set_tensor_num", &VarDesc::SetTensorDescNum) + .def("tensor_num", &VarDesc::GetTensorDescNum) + .def("shape", &VarDesc::GetShape, py::return_value_policy::reference) + .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference) .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference) + .def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference) .def("lod_level", &VarDesc::GetLoDLevel) + .def("lod_levels", &VarDesc::GetLoDLevels, + py::return_value_policy::reference) .def("set_lod_level", &VarDesc::SetLoDLevel) + .def("set_lod_levels", &VarDesc::SetLoDLevels) .def("type", &VarDesc::GetType) .def("set_type", &VarDesc::SetType) .def("serialize_to_string", SerializeMessage) @@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) { .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES) .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE) .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY) - .value("PLACE_LIST", proto::VarDesc::PLACE_LIST); + .value("PLACE_LIST", proto::VarDesc::PLACE_LIST) + .value("READER", proto::VarDesc::READER); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/fluid/tests/test_protobuf_descs.py b/python/paddle/v2/fluid/tests/test_protobuf_descs.py index 9034b2f4ef1..ac6de68b5fe 100644 --- a/python/paddle/v2/fluid/tests/test_protobuf_descs.py +++ b/python/paddle/v2/fluid/tests/test_protobuf_descs.py @@ -115,6 +115,20 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(src_shape, res_shape) self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) + def test_multiple_shape(self): + program_desc = core.ProgramDesc() + block = program_desc.block(0) + var = block.var('my_reader') + var.set_type(core.VarDesc.VarType.READER) + var.set_tensor_num(3) + src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]] + var.set_shapes(src_shapes) + #import pdb + # pdb.set_trace() + res_shapes = var.shapes() + self.assertEqual(src_shapes, res_shapes) + self.assertEqual(core.VarDesc.VarType.READER, var.type()) + def test_dtype(self): program_desc = core.ProgramDesc() block = program_desc.block(0) @@ -124,6 +138,30 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(core.DataType.INT32, var.dtype()) self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type()) + def test_multiple_dtype(self): + program_desc = core.ProgramDesc() + block = program_desc.block(0) + var = block.var('my_reader') + var.set_type(core.VarDesc.VarType.READER) + var.set_tensor_num(3) + src_types = [ + core.DataType.INT32, core.DataType.FP64, core.DataType.FP32 + ] + var.set_dtypes(src_types) + self.assertEqual(src_types, var.dtypes()) + self.assertEqual(core.VarDesc.VarType.READER, var.type()) + + def test_multiple_lod_level(self): + program_desc = core.ProgramDesc() + block = program_desc.block(0) + var = block.var('my_reader') + var.set_type(core.VarDesc.VarType.READER) + var.set_tensor_num(3) + src_types = [3, 1, 2] + var.set_lod_levels(src_types) + self.assertEqual(src_types, var.lod_levels()) + self.assertEqual(core.VarDesc.VarType.READER, var.type()) + class TestBlockDesc(unittest.TestCase): def test_add_var(self): -- GitLab