diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 008a9441db408e2db99ec86af645344f376d0b8f..fcae92ad99f7b393104ac04fd725ad3d43db04ad 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/framework/executor.h" #include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/lod_tensor.h" +#include "paddle/framework/selected_rows.h" #include "paddle/framework/tensor_array.h" #include "paddle/operators/cond_op.h" #include "paddle/operators/dynamic_recurrent_op.h" @@ -139,6 +140,32 @@ PYBIND11_PLUGIN(core) { #endif }); + py::class_(m, "SelectedRows") + .def("__init__", + [](SelectedRows &instance) { new (&instance) SelectedRows(); }) + .def("__init__", + [](SelectedRows &instance, const std::vector rows, + const int64_t &height) { + new (&instance) SelectedRows(rows, height); + }) + .def("get_tensor", + [](SelectedRows &self) { return self.mutable_value(); }, + py::return_value_policy::reference) + .def("set_height", &SelectedRows::set_height) + .def("height", &SelectedRows::height) + .def("set_rows", &SelectedRows::set_rows) + .def("rows", [](SelectedRows &self) { +#ifndef PADDLE_WITH_CUDA + return self.rows(); +#else + auto rows = self.rows(); + std::vector new_rows; + new_rows.reserve(rows.size()); + std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows)); + return new_rows; +#endif + }); + py::class_(m, "Variable", R"DOC(Variable Class. All parameter, weight, gradient are variables in Paddle. diff --git a/python/paddle/v2/framework/tests/test_selected_rows.py b/python/paddle/v2/framework/tests/test_selected_rows.py new file mode 100644 index 0000000000000000000000000000000000000000..661e81817951f5605ba3ca7fb0cc667074b1e37c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_selected_rows.py @@ -0,0 +1,37 @@ +import paddle.v2.framework.core as core +import unittest +import numpy as np + + +class TestSelectedRows(unittest.TestCase): + def test_selected_rows(self): + place = core.CPUPlace() + height = 10 + rows = [0, 4, 7] + row_numel = 10 + selcted_rows = core.SelectedRows(rows, row_numel) + np_array = np.ones((len(rows), height)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + tensor = selcted_rows.get_tensor() + tensor.set(np_array, place) + + # compare rows + self.assertEqual(0, selcted_rows.rows()[0]) + self.assertEqual(4, selcted_rows.rows()[1]) + self.assertEqual(7, selcted_rows.rows()[2]) + + # compare height + self.assertEqual(10, selcted_rows.height()) + + # compare tensor + self.assertAlmostEqual(2.0, + selcted_rows.get_tensor().get_float_element(0)) + self.assertAlmostEqual(1.0, + selcted_rows.get_tensor().get_float_element(1)) + self.assertAlmostEqual( + 4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8)) + + +if __name__ == "__main__": + unittest.main()