提交 cdc236cb 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #4841 from QiJune/pybind_selected_rows

export SelectedRows to Python
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h" #include "paddle/framework/tensor_array.h"
#include "paddle/operators/cond_op.h" #include "paddle/operators/cond_op.h"
#include "paddle/operators/dynamic_recurrent_op.h" #include "paddle/operators/dynamic_recurrent_op.h"
...@@ -139,6 +140,32 @@ PYBIND11_PLUGIN(core) { ...@@ -139,6 +140,32 @@ PYBIND11_PLUGIN(core) {
#endif #endif
}); });
py::class_<SelectedRows>(m, "SelectedRows")
.def("__init__",
[](SelectedRows &instance) { new (&instance) SelectedRows(); })
.def("__init__",
[](SelectedRows &instance, const std::vector<int64_t> rows,
const int64_t &height) {
new (&instance) SelectedRows(rows, height);
})
.def("get_tensor",
[](SelectedRows &self) { return self.mutable_value(); },
py::return_value_policy::reference)
.def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height)
.def("set_rows", &SelectedRows::set_rows)
.def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
return self.rows();
#else
auto rows = self.rows();
std::vector<int64_t> new_rows;
new_rows.reserve(rows.size());
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows;
#endif
});
py::class_<Variable>(m, "Variable", R"DOC(Variable Class. py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
All parameter, weight, gradient are variables in Paddle. All parameter, weight, gradient are variables in Paddle.
......
import paddle.v2.framework.core as core
import unittest
import numpy as np
class TestSelectedRows(unittest.TestCase):
def test_selected_rows(self):
place = core.CPUPlace()
height = 10
rows = [0, 4, 7]
row_numel = 10
selcted_rows = core.SelectedRows(rows, row_numel)
np_array = np.ones((len(rows), height)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
tensor = selcted_rows.get_tensor()
tensor.set(np_array, place)
# compare rows
self.assertEqual(0, selcted_rows.rows()[0])
self.assertEqual(4, selcted_rows.rows()[1])
self.assertEqual(7, selcted_rows.rows()[2])
# compare height
self.assertEqual(10, selcted_rows.height())
# compare tensor
self.assertAlmostEqual(2.0,
selcted_rows.get_tensor().get_float_element(0))
self.assertAlmostEqual(1.0,
selcted_rows.get_tensor().get_float_element(1))
self.assertAlmostEqual(
4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册