提交 bea82122 编写于 作者: D dangqingqing

Expose LoDTensor to pybind.

上级 b59f3018
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/backward.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
......@@ -54,6 +55,7 @@ namespace paddle {
namespace framework {
using Tensor = framework::Tensor;
using LODTensor = framework::LODTensor;
static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
......@@ -113,6 +115,25 @@ PYBIND11_PLUGIN(core) {
return self.data<float>()[offset];
});
py::class_<LODTensor>(m, "LODTensor", R"DOC(LOD(Leval of Ddetails) Tensor.
The tensor and LOD info should be created before creating the LODTensor, then
call the set_tensor and set_lod functions to set them.
)DOC")
.def("set_tensor",
[](LODTensor &self, Tensor *tensor) { self.set_tensor(tensor); })
.def("set_lod",
[](LODTensor &self, std::vector<std::vector<size_t>> &lod) {
self.set_lod(lod);
})
.def("get_tensor",
[](LODTensor &self) -> Tensor & { return self.tensor(); },
py::return_value_policy::reference)
.def("get_lod", [](LODTensor &self) -> std::vector<std::vector<size_t>> {
return self.lod();
});
py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
All parameter, weight, gradient are variables in Paddle.
......@@ -124,6 +145,11 @@ All parameter, weight, gradient are variables in Paddle.
.def("get_tensor",
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
py::return_value_policy::reference)
.def("get_lod_tensor",
[](Variable &self) -> LODTensor * {
return self.GetMutable<LODTensor>();
},
py::return_value_policy::reference)
.def("get_net",
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册