未验证 提交 1213e283 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #14820 from jacquesqiao/fix-split-selected-rows

split selected rows op should always init output selected rows
......@@ -72,10 +72,11 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(height_sections[i]);
auto dims = x->GetCompleteDims();
dims[0] = rows_idx.size();
outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
outs[i]->mutable_rows()->clear();
if (rows_idx.size() > 0) {
auto dims = x->GetCompleteDims();
dims[0] = rows_idx.size();
outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
}
......@@ -98,6 +99,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
}
}
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
"rows should has the same size with tensor dim 0");
}
}
};
......
......@@ -336,6 +336,8 @@ PYBIND11_MODULE(core, m) {
.def("get_tensor",
[](SelectedRows &self) { return self.mutable_value(); },
py::return_value_policy::reference)
.def("numel",
[](SelectedRows &self) -> int64_t { return self.value().numel(); })
.def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height)
.def("set_rows",
......
......@@ -63,6 +63,7 @@ class TestSpliteSelectedRows(unittest.TestCase):
# expected output selected rows
expected_out0_rows = [0, 4]
expected_out1_rows = [0, 2]
expected_out2_rows = []
expected_out4_rows = [0]
op = Operator(
......@@ -75,6 +76,7 @@ class TestSpliteSelectedRows(unittest.TestCase):
self.assertEqual(outs[0].rows(), expected_out0_rows)
self.assertEqual(outs[1].rows(), expected_out1_rows)
self.assertEqual(outs[2].rows(), expected_out2_rows)
self.assertEqual(outs[4].rows(), expected_out4_rows)
self.assertEqual(outs[0].height(), height_sections[0])
......@@ -84,6 +86,9 @@ class TestSpliteSelectedRows(unittest.TestCase):
self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1])
self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1])
self.assertEqual(outs[2].numel(), 0)
self.assertEqual(outs[3].numel(), 0)
def check_grad_with_place(self, place):
scope = core.Scope()
height = 10
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册