提交 c65bdd95 编写于 作者: Q qijun

fix SelectedRows rows() method gpu runtime error

上级 97069927
...@@ -153,7 +153,17 @@ PYBIND11_PLUGIN(core) { ...@@ -153,7 +153,17 @@ PYBIND11_PLUGIN(core) {
.def("set_height", &SelectedRows::set_height) .def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height) .def("height", &SelectedRows::height)
.def("set_rows", &SelectedRows::set_rows) .def("set_rows", &SelectedRows::set_rows)
.def("rows", &SelectedRows::rows, py::return_value_policy::reference); .def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
return self.rows();
#else
auto rows = self.rows();
std::vector<int64_t> new_rows;
new_rows.reserve(rows.size());
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows;
#endif
});
py::class_<Variable>(m, "Variable", R"DOC(Variable Class. py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册