From 0a7c7e97afc20d26406de27a22c9cd4b7edad8b0 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 11 Dec 2018 10:45:16 +0800 Subject: [PATCH] test zero output of split_selected_rows_op test=develop --- paddle/fluid/pybind/pybind.cc | 2 ++ .../fluid/tests/unittests/test_split_selected_rows_op.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 58ef3da0b2..9d92529e4e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -298,6 +298,8 @@ PYBIND11_MODULE(core, m) { .def("get_tensor", [](SelectedRows &self) { return self.mutable_value(); }, py::return_value_policy::reference) + .def("numel", + [](SelectedRows &self) -> int64_t { return self.value().numel(); }) .def("set_height", &SelectedRows::set_height) .def("height", &SelectedRows::height) .def("set_rows", diff --git a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py index 50204b8a77..f8847e1570 100644 --- a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py @@ -63,6 +63,7 @@ class TestSpliteSelectedRows(unittest.TestCase): # expected output selected rows expected_out0_rows = [0, 4] expected_out1_rows = [0, 2] + expected_out2_rows = [] expected_out4_rows = [0] op = Operator( @@ -75,6 +76,7 @@ class TestSpliteSelectedRows(unittest.TestCase): self.assertEqual(outs[0].rows(), expected_out0_rows) self.assertEqual(outs[1].rows(), expected_out1_rows) + self.assertEqual(outs[2].rows(), expected_out2_rows) self.assertEqual(outs[4].rows(), expected_out4_rows) self.assertEqual(outs[0].height(), height_sections[0]) @@ -84,6 +86,9 @@ class TestSpliteSelectedRows(unittest.TestCase): self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1]) self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1]) + self.assertEqual(outs[2].numel(), 0) + self.assertEqual(outs[3].numel(), 0) + def check_grad_with_place(self, place): scope = core.Scope() height = 10 -- GitLab