From bea82122dd3c66e3a4cd69939a7ac68f7cce9524 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 7 Sep 2017 15:29:42 +0800 Subject: [PATCH] Expose LoDTensor to pybind. --- paddle/pybind/pybind.cc | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index ba28b51ad..0b9d2697d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/framework/backward.h" +#include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" @@ -54,6 +55,7 @@ namespace paddle { namespace framework { using Tensor = framework::Tensor; +using LODTensor = framework::LODTensor; static size_t UniqueIntegerGenerator() { static std::atomic generator; @@ -113,6 +115,25 @@ PYBIND11_PLUGIN(core) { return self.data()[offset]; }); + py::class_(m, "LODTensor", R"DOC(LOD(Leval of Ddetails) Tensor. + +The tensor and LOD info should be created before creating the LODTensor, then +call the set_tensor and set_lod functions to set them. + +)DOC") + .def("set_tensor", + [](LODTensor &self, Tensor *tensor) { self.set_tensor(tensor); }) + .def("set_lod", + [](LODTensor &self, std::vector> &lod) { + self.set_lod(lod); + }) + .def("get_tensor", + [](LODTensor &self) -> Tensor & { return self.tensor(); }, + py::return_value_policy::reference) + .def("get_lod", [](LODTensor &self) -> std::vector> { + return self.lod(); + }); + py::class_(m, "Variable", R"DOC(Variable Class. All parameter, weight, gradient are variables in Paddle. @@ -124,6 +145,11 @@ All parameter, weight, gradient are variables in Paddle. .def("get_tensor", [](Variable &self) -> Tensor * { return self.GetMutable(); }, py::return_value_policy::reference) + .def("get_lod_tensor", + [](Variable &self) -> LODTensor * { + return self.GetMutable(); + }, + py::return_value_policy::reference) .def("get_net", [](Variable &self) -> operators::NetOp * { return self.GetMutable(); -- GitLab