From 219f46ae16fa3f3e74f28bcbf1b7f815b9b5ac92 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 16 Oct 2017 11:56:15 -0700 Subject: [PATCH] export SelectedRows to Python --- paddle/pybind/pybind.cc | 17 +++++++++ .../v2/framework/tests/test_selected_rows.py | 37 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 python/paddle/v2/framework/tests/test_selected_rows.py diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index afc80b25b18..23e76011c94 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/executor.h" #include "paddle/framework/lod_tensor.h" +#include "paddle/framework/selected_rows.h" #include "paddle/framework/tensor_array.h" #include "paddle/operators/cond_op.h" #include "paddle/operators/dynamic_recurrent_op.h" @@ -138,6 +139,22 @@ PYBIND11_PLUGIN(core) { #endif }); + py::class_(m, "SelectedRows") + .def("__init__", + [](SelectedRows &instance) { new (&instance) SelectedRows(); }) + .def("__init__", + [](SelectedRows &instance, const std::vector rows, + const int64_t &height) { + new (&instance) SelectedRows(rows, height); + }) + .def("get_tensor", + [](SelectedRows &self) { return self.mutable_value(); }, + py::return_value_policy::reference) + .def("set_height", &SelectedRows::set_height) + .def("height", &SelectedRows::height) + .def("set_rows", &SelectedRows::set_rows) + .def("rows", &SelectedRows::rows, py::return_value_policy::reference); + py::class_(m, "Variable", R"DOC(Variable Class. All parameter, weight, gradient are variables in Paddle. diff --git a/python/paddle/v2/framework/tests/test_selected_rows.py b/python/paddle/v2/framework/tests/test_selected_rows.py new file mode 100644 index 00000000000..661e8181795 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_selected_rows.py @@ -0,0 +1,37 @@ +import paddle.v2.framework.core as core +import unittest +import numpy as np + + +class TestSelectedRows(unittest.TestCase): + def test_selected_rows(self): + place = core.CPUPlace() + height = 10 + rows = [0, 4, 7] + row_numel = 10 + selcted_rows = core.SelectedRows(rows, row_numel) + np_array = np.ones((len(rows), height)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + tensor = selcted_rows.get_tensor() + tensor.set(np_array, place) + + # compare rows + self.assertEqual(0, selcted_rows.rows()[0]) + self.assertEqual(4, selcted_rows.rows()[1]) + self.assertEqual(7, selcted_rows.rows()[2]) + + # compare height + self.assertEqual(10, selcted_rows.height()) + + # compare tensor + self.assertAlmostEqual(2.0, + selcted_rows.get_tensor().get_float_element(0)) + self.assertAlmostEqual(1.0, + selcted_rows.get_tensor().get_float_element(1)) + self.assertAlmostEqual( + 4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8)) + + +if __name__ == "__main__": + unittest.main() -- GitLab