diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 90b995decb9a6b5b684efc44abe8fc280ed73adc..e1f7bc8672eebac3c311efb30f2a2630ec3a4587 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -328,7 +328,31 @@ All parameter, weight, gradient are variables in Paddle. [](BlockDesc &self, int32_t idx) { self.set_parent_idx(idx); }) .def("parent", [](BlockDesc &self) { return self.parent_idx(); }); - py::class_(m, "VarDesc", ""); + py::class_(m, "VarDesc", "") + .def(py::init<>()) + .def("set_name", + [](VarDesc &self, const std::string &name) { self.set_name(name); }) + .def("set_shape", + [](VarDesc &self, const std::vector &dims) { + LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); + for (const int64_t &i : dims) { + lod_tensor_desc->add_dims(i); + } + }) + .def("set_data_type", + [](VarDesc &self, int type_id) { + LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); + lod_tensor_desc->set_data_type(static_cast(type_id)); + }) + .def("shape", [](VarDesc &self) { + const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); + int rank = lod_tensor_desc.dims_size(); + std::vector res(rank); + for (int i = 0; i < rank; ++i) { + res[i] = lod_tensor_desc.dims(i); + } + return res; + }); py::class_(m, "OpDesc", "");