提交 0a7c7e97 编写于 作者: Q Qiao Longfei

test zero output of split_selected_rows_op

test=develop
上级 abf14028
......@@ -298,6 +298,8 @@ PYBIND11_MODULE(core, m) {
.def("get_tensor",
[](SelectedRows &self) { return self.mutable_value(); },
py::return_value_policy::reference)
.def("numel",
[](SelectedRows &self) -> int64_t { return self.value().numel(); })
.def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height)
.def("set_rows",
......
......@@ -63,6 +63,7 @@ class TestSpliteSelectedRows(unittest.TestCase):
# expected output selected rows
expected_out0_rows = [0, 4]
expected_out1_rows = [0, 2]
expected_out2_rows = []
expected_out4_rows = [0]
op = Operator(
......@@ -75,6 +76,7 @@ class TestSpliteSelectedRows(unittest.TestCase):
self.assertEqual(outs[0].rows(), expected_out0_rows)
self.assertEqual(outs[1].rows(), expected_out1_rows)
self.assertEqual(outs[2].rows(), expected_out2_rows)
self.assertEqual(outs[4].rows(), expected_out4_rows)
self.assertEqual(outs[0].height(), height_sections[0])
......@@ -84,6 +86,9 @@ class TestSpliteSelectedRows(unittest.TestCase):
self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1])
self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1])
self.assertEqual(outs[2].numel(), 0)
self.assertEqual(outs[3].numel(), 0)
def check_grad_with_place(self, place):
scope = core.Scope()
height = 10
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册