From eeb7c8ad795d6d7159d3659a2d41709653e2e347 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 22 Sep 2017 17:34:47 -0700 Subject: [PATCH] Compelete VarDescBind --- paddle/pybind/protobuf.cc | 44 ++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 5511841c8b5..126c2ce1c7d 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -46,12 +46,24 @@ class VarDescBind; class VarDescBind { public: - explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); } + explicit VarDescBind(const std::string &name) { desc_.set_name(name); } - VarDesc *Proto() { return &var_desc_; } + VarDesc *Proto() { return &desc_; } + + void SetShape(const vector &dims) { + VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); + } + + void SetDataType(int type_id) { + desc_.mutable_lod_tensor()->set_data_type(const_cast(type_id)); + } + + std::vector Shape() { + return RepeatedToVector(desc_.lod_tensor().dims()); + } private: - VarDesc var_desc_; + VarDesc desc_; }; class OpDescBind { @@ -217,27 +229,11 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { - py::class_(m, "VarDesc", ""); - // using namespace paddle::framework; // NOLINT - // py::class_(m, "VarDesc", "") - // .def(py::init<>()) - // .def("set_name", - // [](VarDesc &self, const std::string &name) { self.set_name(name); - // }) - // .def("set_shape", - // [](VarDesc &self, const std::vector &dims) { - // VectorToRepeated(dims, - // self.mutable_lod_tensor()->mutable_dims()); - // }) - // .def("set_data_type", - // [](VarDesc &self, int type_id) { - // LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); - // lod_tensor_desc->set_data_type(static_cast(type_id)); - // }) - // .def("shape", [](VarDesc &self) { - // const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); - // return RepeatedToVector(lod_tensor_desc.dims()); - // }); + py::class_(m, "VarDesc", "") + .def(py::init<>()) + .def("set_shape", VarDescBind::SetShape) + .def("set_data_type", VarDescBind::SetDataType) + .def("shape", VarDescBind::Shape); } void BindOpDesc(py::module &m) { -- GitLab