diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index cf83d4cec312ac16366d84f897e7dc4784596ae8..6fcfe6de25737b66a2ea6c1a438636f072a513bb 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -106,7 +106,7 @@ enum DataType { message LoDTensorDesc { required DataType data_type = 1; - repeated int32 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] optional int32 lod_level = 3 [ default = 0 ]; } diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 345bb02c86e3b88570df738c09ff847168929705..23c322ac3651fa3668a55bd1e2f6010ec24e2da6 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -47,12 +47,24 @@ class VarDescBind; class VarDescBind { public: - explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); } + explicit VarDescBind(const std::string &name) { desc_.set_name(name); } - VarDesc *Proto() { return &var_desc_; } + VarDesc *Proto() { return &desc_; } + + void SetShape(const std::vector &dims) { + VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); + } + + void SetDataType(int type_id) { + desc_.mutable_lod_tensor()->set_data_type(static_cast(type_id)); + } + + std::vector Shape() { + return RepeatedToVector(desc_.lod_tensor().dims()); + } private: - VarDesc var_desc_; + VarDesc desc_; }; class OpDescBind { @@ -170,7 +182,8 @@ public: int32_t Parent() const { return desc_->parent_idx(); } - VarDescBind *NewVar(const std::string &name) { + VarDescBind *NewVar(py::bytes name_bytes) { + std::string name = name_bytes; need_update_ = true; auto it = vars_.find(name); PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); @@ -303,32 +316,15 @@ void BindBlockDesc(py::module &m) { &BlockDescBind::AppendOp, py::return_value_policy::reference) .def("new_var", - [](BlockDesc &self) { return self.add_vars(); }, + &BlockDescBind::NewVar, py::return_value_policy::reference); } void BindVarDsec(py::module &m) { - py::class_(m, "VarDesc", ""); - // using namespace paddle::framework; // NOLINT - // py::class_(m, "VarDesc", "") - // .def(py::init<>()) - // .def("set_name", - // [](VarDesc &self, const std::string &name) { self.set_name(name); - // }) - // .def("set_shape", - // [](VarDesc &self, const std::vector &dims) { - // VectorToRepeated(dims, - // self.mutable_lod_tensor()->mutable_dims()); - // }) - // .def("set_data_type", - // [](VarDesc &self, int type_id) { - // LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); - // lod_tensor_desc->set_data_type(static_cast(type_id)); - // }) - // .def("shape", [](VarDesc &self) { - // const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); - // return RepeatedToVector(lod_tensor_desc.dims()); - // }); + py::class_(m, "VarDesc", "") + .def("set_shape", &VarDescBind::SetShape) + .def("set_data_type", &VarDescBind::SetDataType) + .def("shape", &VarDescBind::Shape); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 950a9363071b64c0fc24b17822f121f2c66207cb..2e96dcced5280fcc01bb2582bbedd9f28ffa608d 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -52,7 +52,7 @@ class TestVarDesc(unittest.TestCase): def test_shape(self): program_desc = core.ProgramDesc.instance() block = program_desc.root_block() - var = block.new_var() + var = block.new_var('my_var') src_shape = [3, 2, 10, 8] var.set_shape(src_shape) res_shape = var.shape()