提交 7dabee27 编写于 作者: F fengjiayi

Add type Reader for VarDesc

Add a new type `Reader` for `VarDesc`, which can hold more than one
LoDTensor.
上级 71bd0dfa
......@@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward(
auto root_block = program_desc.MutableBlock(root_block_idx);
std::string fill_one_op_out = GradVarName(target.Name());
bool is_scalar = target.Shape() == std::vector<int64_t>{1};
bool is_scalar = target.GetShape() == std::vector<int64_t>{1};
PADDLE_ENFORCE(is_scalar, "target should be scalar");
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
......@@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward(
auto var = root_block->Var(fill_one_op_out);
var->SetDataType(target.GetDataType());
var->SetShape(target.Shape());
var->SetShape(target.GetShape());
auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out;
target_grad.block_idx_ = root_block_idx;
......
......@@ -116,6 +116,8 @@ message LoDTensorArrayDesc {
optional int32 lod_level = 2 [ default = 0 ];
}
// Reader carries a repeated list of LoDTensorDesc entries (field
// `lod_tensor`), so a single message can describe several LoDTensors.
message Reader { repeated LoDTensorDesc lod_tensor = 1; }
message VarDesc {
enum VarType {
LOD_TENSOR = 1;
......@@ -126,13 +128,15 @@ message VarDesc {
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
}
required string name = 1;
required VarType type = 2;
optional LoDTensorDesc lod_tensor = 3;
optional TensorDesc selected_rows = 4;
optional bool persistable = 3 [ default = false ];
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional bool persistable = 5 [ default = false ];
optional Reader reader = 7;
}
message BlockDesc {
......
......@@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
try {
auto shape = var->Shape();
auto shape = var->GetShape();
if (shape.empty()) {
return framework::make_ddim({0UL});
} else {
return framework::make_ddim(var->Shape());
return framework::make_ddim(var->GetShape());
}
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
......
......@@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_NE(copy, var_before);
ASSERT_EQ(copy->Name(), var_before->Name());
ASSERT_EQ(copy->GetType(), var_before->GetType());
ASSERT_EQ(copy->Shape(), var_before->Shape());
ASSERT_EQ(copy->GetShape(), var_before->GetShape());
ASSERT_EQ(copy->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
......@@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ASSERT_NE(restored, var_before);
ASSERT_EQ(restored->Name(), var_before->Name());
ASSERT_EQ(restored->GetType(), var_before->GetType());
ASSERT_EQ(restored->Shape(), var_before->Shape());
ASSERT_EQ(restored->GetShape(), var_before->GetShape());
ASSERT_EQ(restored->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
......
......@@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
// Resets the reader's sub tensor list so that it holds exactly `num`
// freshly added LoDTensorDesc entries. Only READER vars are accepted; any
// other var type makes this throw.
void VarDesc::SetTensorDescNum(size_t num) {
  if (desc_.type() != proto::VarDesc::READER) {
    PADDLE_THROW(
        "Setting 'sub_tensor_number' is not supported by the type of var %s.",
        this->Name());
  }
  auto *sub_tensors = desc_.mutable_reader()->mutable_lod_tensor();
  // Drop whatever entries were there before, then append `num` new ones.
  sub_tensors->Clear();
  for (size_t remaining = num; remaining > 0; --remaining) {
    sub_tensors->Add();
  }
}
// Returns how many LoDTensorDesc entries the reader currently holds.
// Throws for every var type other than READER.
size_t VarDesc::GetTensorDescNum() const {
  if (desc_.type() == proto::VarDesc::READER) {
    return desc_.reader().lod_tensor_size();
  }
  PADDLE_THROW(
      "Getting 'sub_tensor_number' is not supported by the type of var %s.",
      this->Name());
}
// Writes one shape into each sub tensor desc of this var.
//
// `multiple_dims[i]` becomes the dims of the i-th sub tensor. The number of
// shapes must equal GetTensorDescNum(); otherwise the enforce fails. Because
// GetTensorDescNum() is consulted first, var types it rejects make this
// function throw as well.
//
// NOTE(review): the element type `const std::vector<int64_t>` is kept as-is
// because it is part of the declared signature; a vector of const elements is
// not portable across standard library implementations -- consider changing
// the declaration and this definition together.
void VarDesc::SetShapes(
    const std::vector<const std::vector<int64_t>> &multiple_dims) {
  const size_t tensor_num = GetTensorDescNum();
  // The message has one placeholder per reported value: the number of shapes
  // given and the number of sub tensors expected.
  PADDLE_ENFORCE_EQ(multiple_dims.size(), tensor_num,
                    "The number of given shapes(%d) doesn't equal to the "
                    "number of sub tensor(%d).",
                    multiple_dims.size(), tensor_num);
  std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
  for (size_t i = 0; i < multiple_dims.size(); ++i) {
    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
  }
}
std::vector<int64_t> VarDesc::GetShape() const {
return RepeatedToVector(tensor_desc().dims());
}
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(RepeatedToVector(tensor_desc.dims()));
}
return res;
}
// Stores `data_type` in this var's (single) tensor desc.
void VarDesc::SetDataType(proto::DataType data_type) {
  proto::TensorDesc *tensor = mutable_tensor_desc();
  tensor->set_data_type(data_type);
}
std::vector<int64_t> VarDesc::Shape() const {
return RepeatedToVector(tensor_desc().dims());
// Writes one data type into each sub tensor desc of this var.
//
// `multiple_data_type[i]` becomes the data type of the i-th sub tensor. The
// number of data types must equal GetTensorDescNum(); otherwise the enforce
// fails. Because GetTensorDescNum() is consulted first, var types it rejects
// make this function throw as well.
void VarDesc::SetDataTypes(
    const std::vector<proto::DataType> &multiple_data_type) {
  const size_t tensor_num = GetTensorDescNum();
  // The message has one placeholder per reported value: the number of data
  // types given and the number of sub tensors expected.
  PADDLE_ENFORCE_EQ(multiple_data_type.size(), tensor_num,
                    "The number of given data types(%d) doesn't equal to the "
                    "number of sub tensor(%d).",
                    multiple_data_type.size(), tensor_num);
  std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
    tensor_descs[i]->set_data_type(multiple_data_type[i]);
  }
}
// Returns the data type recorded in this var's (single) tensor desc.
proto::DataType VarDesc::GetDataType() const {
  const proto::TensorDesc &tensor = tensor_desc();
  return tensor.data_type();
}
std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
}
return res;
}
void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR:
......@@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
desc_.mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
PADDLE_THROW(
"Setting 'lod_level' is not supported by the type of var %s.",
this->Name());
}
}
// Writes one lod level into each sub tensor of a READER var.
//
// `multiple_lod_level[i]` becomes the lod level of the i-th LoDTensorDesc of
// the reader. The number of lod levels must equal GetTensorDescNum();
// otherwise the enforce fails. Var types other than READER throw.
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
  const size_t tensor_num = GetTensorDescNum();
  // The message names what was actually counted (lod levels, not data types)
  // and has one placeholder per reported value.
  PADDLE_ENFORCE_EQ(multiple_lod_level.size(), tensor_num,
                    "The number of given lod levels(%d) doesn't equal to the "
                    "number of sub tensor(%d).",
                    multiple_lod_level.size(), tensor_num);
  switch (desc_.type()) {
    case proto::VarDesc::READER: {
      size_t i = 0;
      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
        lod_tensor.set_lod_level(multiple_lod_level[i++]);
      }
    } break;
    default:
      PADDLE_THROW(
          "Setting 'lod_levels' is not supported by the type of var %s.",
          this->Name());
  }
}
......@@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level();
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
PADDLE_THROW(
"Getting 'lod_level' is not supported by the type of var %s.",
this->Name());
}
}
// Returns the lod level of every LoDTensorDesc held by a READER var, in
// storage order. Throws for every other var type.
std::vector<int32_t> VarDesc::GetLoDLevels() const {
  if (desc_.type() != proto::VarDesc::READER) {
    PADDLE_THROW(
        "Getting 'lod_levels' is not supported by the type of var %s.",
        this->Name());
  }
  std::vector<int32_t> levels;
  levels.reserve(desc_.reader().lod_tensor_size());
  for (const auto &sub_tensor : desc_.reader().lod_tensor()) {
    levels.push_back(sub_tensor.lod_level());
  }
  return levels;
}
const proto::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
......@@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
default:
PADDLE_THROW("The type of var %s is unsupported.", this->Name());
PADDLE_THROW(
"Getting 'tensor_desc' is not supported by the type of var %s.",
this->Name());
}
}
std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
res.push_back(lod_tensor.tensor());
}
return res;
default:
PADDLE_THROW(
"Getting 'tensor_descs' is not supported by the type of var "
"%s.",
this->Name());
}
}
proto::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type");
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
......@@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.mutable_tensor_array()->mutable_tensor();
default:
PADDLE_THROW("Unexpected branch.");
PADDLE_THROW(
"Getting 'mutable_tensor_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}
// Returns a mutable pointer to the TensorDesc embedded in each LoDTensorDesc
// of a READER var, in storage order. The var's type must already be set.
// The pointers refer to messages stored inside desc_.
std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
  std::vector<proto::TensorDesc *> res;
  // GetTensorDescNum() runs before the switch, so it is the first place an
  // unsupported var type is rejected.
  res.reserve(GetTensorDescNum());
  switch (desc_.type()) {
    case proto::VarDesc::READER:
      for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
        res.push_back(lod_tensor.mutable_tensor());
      }
      return res;
    default:
      // Name this accessor in the message, not the const tensor_descs().
      PADDLE_THROW(
          "Getting 'mutable_tensor_descs' is not supported by the type of var "
          "%s.",
          this->Name());
  }
}
} // namespace framework
} // namespace paddle
......@@ -68,18 +68,34 @@ class VarDesc {
void SetName(std::string name) { desc_.set_name(name); }
void SetTensorDescNum(size_t num);
size_t GetTensorDescNum() const;
void SetShape(const std::vector<int64_t> &dims);
void SetShapes(const std::vector<const std::vector<int64_t>> &multiple_dims);
std::vector<int64_t> GetShape() const;
std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(proto::DataType data_type);
std::vector<int64_t> Shape() const;
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);
proto::DataType GetDataType() const;
std::vector<proto::DataType> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level);
void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);
int32_t GetLoDLevel() const;
std::vector<int32_t> GetLoDLevels() const;
proto::VarDesc::VarType GetType() const;
void SetType(proto::VarDesc::VarType type);
......@@ -90,7 +106,9 @@ class VarDesc {
private:
const proto::TensorDesc &tensor_desc() const;
std::vector<proto::TensorDesc> tensor_descs() const;
proto::TensorDesc *mutable_tensor_desc();
std::vector<proto::TensorDesc *> mutable_tensor_descs();
proto::VarDesc desc_;
};
......
......@@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor,
VLOG(3) << "parameter's name: " << var->Name();
framework::VarDesc* new_var = load_block->Var(var->Name());
new_var->SetShape(var->Shape());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
......
......@@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) {
py::return_value_policy::reference)
.def("set_name", &VarDesc::SetName)
.def("set_shape", &VarDesc::SetShape)
.def("set_shapes", &VarDesc::SetShapes)
.def("set_dtype", &VarDesc::SetDataType)
.def("shape", &VarDesc::Shape, py::return_value_policy::reference)
.def("set_dtypes", &VarDesc::SetDataTypes)
.def("set_tensor_num", &VarDesc::SetTensorDescNum)
.def("tensor_num", &VarDesc::GetTensorDescNum)
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
.def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference)
.def("lod_level", &VarDesc::GetLoDLevel)
.def("lod_levels", &VarDesc::GetLoDLevels,
py::return_value_policy::reference)
.def("set_lod_level", &VarDesc::SetLoDLevel)
.def("set_lod_levels", &VarDesc::SetLoDLevels)
.def("type", &VarDesc::GetType)
.def("set_type", &VarDesc::SetType)
.def("serialize_to_string", SerializeMessage<VarDesc>)
......@@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) {
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
.value("READER", proto::VarDesc::READER);
}
void BindOpDesc(py::module &m) {
......
......@@ -115,6 +115,20 @@ class TestVarDesc(unittest.TestCase):
self.assertEqual(src_shape, res_shape)
self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())
def test_multiple_shape(self):
    # A READER var sized for three sub tensors should hand back exactly the
    # list of shapes it was given.
    expected_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]]
    program_desc = core.ProgramDesc()
    root_block = program_desc.block(0)
    reader_var = root_block.var('my_reader')
    reader_var.set_type(core.VarDesc.VarType.READER)
    reader_var.set_tensor_num(3)
    reader_var.set_shapes(expected_shapes)
    actual_shapes = reader_var.shapes()
    self.assertEqual(expected_shapes, actual_shapes)
    self.assertEqual(core.VarDesc.VarType.READER, reader_var.type())
def test_dtype(self):
program_desc = core.ProgramDesc()
block = program_desc.block(0)
......@@ -124,6 +138,30 @@ class TestVarDesc(unittest.TestCase):
self.assertEqual(core.DataType.INT32, var.dtype())
self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())
def test_multiple_dtype(self):
    # A READER var sized for three sub tensors should hand back exactly the
    # list of data types it was given, in order.
    expected_dtypes = [
        core.DataType.INT32, core.DataType.FP64, core.DataType.FP32
    ]
    program_desc = core.ProgramDesc()
    root_block = program_desc.block(0)
    reader_var = root_block.var('my_reader')
    reader_var.set_type(core.VarDesc.VarType.READER)
    reader_var.set_tensor_num(3)
    reader_var.set_dtypes(expected_dtypes)
    self.assertEqual(expected_dtypes, reader_var.dtypes())
    self.assertEqual(core.VarDesc.VarType.READER, reader_var.type())
def test_multiple_lod_level(self):
    # A READER var sized for three sub tensors should hand back exactly the
    # list of lod levels it was given, in order.
    expected_lod_levels = [3, 1, 2]
    program_desc = core.ProgramDesc()
    root_block = program_desc.block(0)
    reader_var = root_block.var('my_reader')
    reader_var.set_type(core.VarDesc.VarType.READER)
    reader_var.set_tensor_num(3)
    reader_var.set_lod_levels(expected_lod_levels)
    self.assertEqual(expected_lod_levels, reader_var.lod_levels())
    self.assertEqual(core.VarDesc.VarType.READER, reader_var.type())
class TestBlockDesc(unittest.TestCase):
def test_add_var(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册