Commit 7dabee27 authored by F fengjiayi

Add type Reader for VarDesc

Add a new type `Reader` for `VarDesc`, which can hold more than one
LoDTensor.
Parent 71bd0dfa
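The net effect of the commit is easiest to see through the Python bindings it extends. Below is a minimal sketch mirroring the unit tests added at the end of this diff; the import path (`paddle.v2.fluid.core`) and the variable name `my_reader` are assumptions for illustration, not part of the commit itself.

```python
import paddle.v2.fluid.core as core  # assumed import path for this era of Paddle

program_desc = core.ProgramDesc()
block = program_desc.block(0)

var = block.var('my_reader')
var.set_type(core.VarDesc.VarType.READER)  # the new READER variable type
var.set_tensor_num(3)                      # this reader yields three LoDTensors

# Each sub-tensor carries its own shape, data type, and LoD level.
var.set_shapes([[2, 3, 3], [4, 5], [6, 7, 8, 9]])
var.set_dtypes([core.DataType.INT32, core.DataType.FP64, core.DataType.FP32])
var.set_lod_levels([3, 1, 2])

assert var.shapes() == [[2, 3, 3], [4, 5], [6, 7, 8, 9]]
assert var.lod_levels() == [3, 1, 2]
```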
@@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward(
   auto root_block = program_desc.MutableBlock(root_block_idx);

   std::string fill_one_op_out = GradVarName(target.Name());
-  bool is_scalar = target.Shape() == std::vector<int64_t>{1};
+  bool is_scalar = target.GetShape() == std::vector<int64_t>{1};
   PADDLE_ENFORCE(is_scalar, "target should be scalar");
   VLOG(3) << "backward from loss=" << target.Name()
           << " data_type=" << target.GetDataType();
@@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward(
   auto var = root_block->Var(fill_one_op_out);
   var->SetDataType(target.GetDataType());
-  var->SetShape(target.Shape());
+  var->SetShape(target.GetShape());
   auto& target_grad = retv[target.Name()];
   target_grad.name_ = fill_one_op_out;
   target_grad.block_idx_ = root_block_idx;
...
@@ -116,6 +116,8 @@ message LoDTensorArrayDesc {
   optional int32 lod_level = 2 [ default = 0 ];
 }

+message Reader { repeated LoDTensorDesc lod_tensor = 1; }
+
 message VarDesc {
   enum VarType {
     LOD_TENSOR = 1;
@@ -126,13 +128,15 @@ message VarDesc {
     LOD_RANK_TABLE = 6;
     LOD_TENSOR_ARRAY = 7;
     PLACE_LIST = 8;
+    READER = 9;
   }
   required string name = 1;
   required VarType type = 2;
-  optional LoDTensorDesc lod_tensor = 3;
-  optional TensorDesc selected_rows = 4;
+  optional bool persistable = 3 [ default = false ];
+  optional LoDTensorDesc lod_tensor = 4;
+  optional TensorDesc selected_rows = 5;
   optional LoDTensorArrayDesc tensor_array = 6;
-  optional bool persistable = 5 [ default = false ];
+  optional Reader reader = 7;
 }

 message BlockDesc {
...
@@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
   auto var = block_.FindVarRecursive(name);
   PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
   try {
-    auto shape = var->Shape();
+    auto shape = var->GetShape();
     if (shape.empty()) {
       return framework::make_ddim({0UL});
     } else {
-      return framework::make_ddim(var->Shape());
+      return framework::make_ddim(var->GetShape());
     }
   } catch (...) {
     VLOG(5) << "GetDim of variable " << name << " error";
...
@@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) {
     ASSERT_NE(copy, var_before);
     ASSERT_EQ(copy->Name(), var_before->Name());
     ASSERT_EQ(copy->GetType(), var_before->GetType());
-    ASSERT_EQ(copy->Shape(), var_before->Shape());
+    ASSERT_EQ(copy->GetShape(), var_before->GetShape());
     ASSERT_EQ(copy->Proto()->SerializeAsString(),
               var_before->Proto()->SerializeAsString());
   };
@@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
     ASSERT_NE(restored, var_before);
     ASSERT_EQ(restored->Name(), var_before->Name());
     ASSERT_EQ(restored->GetType(), var_before->GetType());
-    ASSERT_EQ(restored->Shape(), var_before->Shape());
+    ASSERT_EQ(restored->GetShape(), var_before->GetShape());
     ASSERT_EQ(restored->Proto()->SerializeAsString(),
               var_before->Proto()->SerializeAsString());
   };
...
@@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
   VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
 }

+void VarDesc::SetTensorDescNum(size_t num) {
+  switch (desc_.type()) {
+    case proto::VarDesc::READER: {
+      auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
+      lod_tensors_ptr->Clear();
+      for (size_t i = 0; i < num; ++i) {
+        lod_tensors_ptr->Add();
+      }
+      return;
+    } break;
+    default:
+      PADDLE_THROW(
+          "Setting 'sub_tensor_number' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+size_t VarDesc::GetTensorDescNum() const {
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      return desc_.reader().lod_tensor_size();
+      break;
+    default:
+      PADDLE_THROW(
+          "Getting 'sub_tensor_number' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+void VarDesc::SetShapes(
+    const std::vector<const std::vector<int64_t>> &multiple_dims) {
+  PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
+                    "The number of given shapes(%d) doesn't equal to the "
+                    "number of sub tensor.",
+                    multiple_dims.size(), GetTensorDescNum());
+  std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
+  for (size_t i = 0; i < multiple_dims.size(); ++i) {
+    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
+  }
+}
+
+std::vector<int64_t> VarDesc::GetShape() const {
+  return RepeatedToVector(tensor_desc().dims());
+}
+
+std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
+  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<std::vector<int64_t>> res;
+  res.reserve(descs.size());
+  for (const auto &tensor_desc : descs) {
+    res.push_back(RepeatedToVector(tensor_desc.dims()));
+  }
+  return res;
+}
+
 void VarDesc::SetDataType(proto::DataType data_type) {
   mutable_tensor_desc()->set_data_type(data_type);
 }

-std::vector<int64_t> VarDesc::Shape() const {
-  return RepeatedToVector(tensor_desc().dims());
-}
+void VarDesc::SetDataTypes(
+    const std::vector<proto::DataType> &multiple_data_type) {
+  PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
+                    "The number of given data types(%d) doesn't equal to the "
+                    "number of sub tensor.",
+                    multiple_data_type.size(), GetTensorDescNum());
+  std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
+  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
+    tensor_descs[i]->set_data_type(multiple_data_type[i]);
+  }
+}

 proto::DataType VarDesc::GetDataType() const {
   return tensor_desc().data_type();
 }

+std::vector<proto::DataType> VarDesc::GetDataTypes() const {
+  std::vector<proto::TensorDesc> descs = tensor_descs();
+  std::vector<proto::DataType> res;
+  res.reserve(descs.size());
+  for (const auto &tensor_desc : descs) {
+    res.push_back(tensor_desc.data_type());
+  }
+  return res;
+}
+
 void VarDesc::SetLoDLevel(int32_t lod_level) {
   switch (desc_.type()) {
     case proto::VarDesc::LOD_TENSOR:
@@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
       desc_.mutable_tensor_array()->set_lod_level(lod_level);
       break;
     default:
-      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
-                   desc_.tensor_array().lod_level());
+      PADDLE_THROW(
+          "Setting 'lod_level' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
+  PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
+                    "The number of given data types(%d) doesn't equal to the "
+                    "number of sub tensor.",
+                    multiple_lod_level.size(), GetTensorDescNum());
+  switch (desc_.type()) {
+    case proto::VarDesc::READER: {
+      size_t i = 0;
+      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+        lod_tensor.set_lod_level(multiple_lod_level[i++]);
+      }
+    } break;
+    default:
+      PADDLE_THROW(
+          "Setting 'lod_levels' is not supported by the type of var %s.",
+          this->Name());
   }
 }
@@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().lod_level();
     default:
-      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
-                   desc_.tensor_array().lod_level());
+      PADDLE_THROW(
+          "Getting 'lod_level' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+std::vector<int32_t> VarDesc::GetLoDLevels() const {
+  std::vector<int32_t> res;
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      res.reserve(desc_.reader().lod_tensor_size());
+      for (auto &lod_tensor : desc_.reader().lod_tensor()) {
+        res.push_back(lod_tensor.lod_level());
+      }
+      return res;
+      break;
+    default:
+      PADDLE_THROW(
+          "Getting 'lod_levels' is not supported by the type of var %s.",
+          this->Name());
   }
 }

 const proto::TensorDesc &VarDesc::tensor_desc() const {
-  PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
+  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
   switch (desc_.type()) {
     case proto::VarDesc::SELECTED_ROWS:
       return desc_.selected_rows();
@@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().tensor();
     default:
-      PADDLE_THROW("The type of var %s is unsupported.", this->Name());
+      PADDLE_THROW(
+          "Getting 'tensor_desc' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
+std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
+  std::vector<proto::TensorDesc> res;
+  res.reserve(GetTensorDescNum());
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
+        res.push_back(lod_tensor.tensor());
+      }
+      return res;
+    default:
+      PADDLE_THROW(
+          "Getting 'tensor_descs' is not supported by the type of var "
+          "%s.",
+          this->Name());
   }
 }

 proto::TensorDesc *VarDesc::mutable_tensor_desc() {
-  PADDLE_ENFORCE(desc_.has_type(),
-                 "invoke MutableTensorDesc must after set type");
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
   switch (desc_.type()) {
     case proto::VarDesc::SELECTED_ROWS:
       return desc_.mutable_selected_rows();
@@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
     case proto::VarDesc::LOD_TENSOR_ARRAY:
       return desc_.mutable_tensor_array()->mutable_tensor();
     default:
-      PADDLE_THROW("Unexpected branch.");
+      PADDLE_THROW(
+          "Getting 'mutable_tensor_desc' is not supported by the type of var "
+          "%s.",
+          this->Name());
   }
 }

+std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
+  std::vector<proto::TensorDesc *> res;
+  res.reserve(GetTensorDescNum());
+  switch (desc_.type()) {
+    case proto::VarDesc::READER:
+      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
+        res.push_back(lod_tensor.mutable_tensor());
+      }
+      return res;
+    default:
+      PADDLE_THROW(
+          "Getting 'tensor_descs' is not supported by the type of var "
+          "%s.",
+          this->Name());
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -68,18 +68,34 @@ class VarDesc {
   void SetName(std::string name) { desc_.set_name(name); }

+  void SetTensorDescNum(size_t num);
+
+  size_t GetTensorDescNum() const;
+
   void SetShape(const std::vector<int64_t> &dims);

+  void SetShapes(const std::vector<const std::vector<int64_t>> &multiple_dims);
+
+  std::vector<int64_t> GetShape() const;
+
+  std::vector<std::vector<int64_t>> GetShapes() const;
+
   void SetDataType(proto::DataType data_type);

-  std::vector<int64_t> Shape() const;
+  void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);

   proto::DataType GetDataType() const;

+  std::vector<proto::DataType> GetDataTypes() const;
+
   void SetLoDLevel(int32_t lod_level);

+  void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);
+
   int32_t GetLoDLevel() const;

+  std::vector<int32_t> GetLoDLevels() const;
+
   proto::VarDesc::VarType GetType() const;

   void SetType(proto::VarDesc::VarType type);
@@ -90,7 +106,9 @@ class VarDesc {
  private:
   const proto::TensorDesc &tensor_desc() const;
+  std::vector<proto::TensorDesc> tensor_descs() const;
   proto::TensorDesc *mutable_tensor_desc();
+  std::vector<proto::TensorDesc *> mutable_tensor_descs();

   proto::VarDesc desc_;
 };
...
@@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor,
       VLOG(3) << "parameter's name: " << var->Name();

       framework::VarDesc* new_var = load_block->Var(var->Name());
-      new_var->SetShape(var->Shape());
+      new_var->SetShape(var->GetShape());
       new_var->SetDataType(var->GetDataType());
       new_var->SetType(var->GetType());
       new_var->SetLoDLevel(var->GetLoDLevel());
...
@@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) {
            py::return_value_policy::reference)
       .def("set_name", &VarDesc::SetName)
       .def("set_shape", &VarDesc::SetShape)
+      .def("set_shapes", &VarDesc::SetShapes)
       .def("set_dtype", &VarDesc::SetDataType)
-      .def("shape", &VarDesc::Shape, py::return_value_policy::reference)
+      .def("set_dtypes", &VarDesc::SetDataTypes)
+      .def("set_tensor_num", &VarDesc::SetTensorDescNum)
+      .def("tensor_num", &VarDesc::GetTensorDescNum)
+      .def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
+      .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
       .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
+      .def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference)
       .def("lod_level", &VarDesc::GetLoDLevel)
+      .def("lod_levels", &VarDesc::GetLoDLevels,
+           py::return_value_policy::reference)
       .def("set_lod_level", &VarDesc::SetLoDLevel)
+      .def("set_lod_levels", &VarDesc::SetLoDLevels)
       .def("type", &VarDesc::GetType)
       .def("set_type", &VarDesc::SetType)
       .def("serialize_to_string", SerializeMessage<VarDesc>)
@@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) {
       .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
       .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
       .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
-      .value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
+      .value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
+      .value("READER", proto::VarDesc::READER);
 }

 void BindOpDesc(py::module &m) {
...
@@ -115,6 +115,20 @@ class TestVarDesc(unittest.TestCase):
         self.assertEqual(src_shape, res_shape)
         self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())

+    def test_multiple_shape(self):
+        program_desc = core.ProgramDesc()
+        block = program_desc.block(0)
+        var = block.var('my_reader')
+        var.set_type(core.VarDesc.VarType.READER)
+        var.set_tensor_num(3)
+        src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]]
+        var.set_shapes(src_shapes)
+        #import pdb
+        # pdb.set_trace()
+        res_shapes = var.shapes()
+        self.assertEqual(src_shapes, res_shapes)
+        self.assertEqual(core.VarDesc.VarType.READER, var.type())
+
     def test_dtype(self):
         program_desc = core.ProgramDesc()
         block = program_desc.block(0)
@@ -124,6 +138,30 @@ class TestVarDesc(unittest.TestCase):
         self.assertEqual(core.DataType.INT32, var.dtype())
         self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())

+    def test_multiple_dtype(self):
+        program_desc = core.ProgramDesc()
+        block = program_desc.block(0)
+        var = block.var('my_reader')
+        var.set_type(core.VarDesc.VarType.READER)
+        var.set_tensor_num(3)
+        src_types = [
+            core.DataType.INT32, core.DataType.FP64, core.DataType.FP32
+        ]
+        var.set_dtypes(src_types)
+        self.assertEqual(src_types, var.dtypes())
+        self.assertEqual(core.VarDesc.VarType.READER, var.type())
+
+    def test_multiple_lod_level(self):
+        program_desc = core.ProgramDesc()
+        block = program_desc.block(0)
+        var = block.var('my_reader')
+        var.set_type(core.VarDesc.VarType.READER)
+        var.set_tensor_num(3)
+        src_types = [3, 1, 2]
+        var.set_lod_levels(src_types)
+        self.assertEqual(src_types, var.lod_levels())
+        self.assertEqual(core.VarDesc.VarType.READER, var.type())
+

 class TestBlockDesc(unittest.TestCase):
     def test_add_var(self):
...