未验证 提交 8ca5206b 编写于 作者: Z zmx 提交者: GitHub

fix SerializeSelectedRows (#36543)

* bug fix for  DeserializeSelectedRows. test=develop

* fix bug for SerializeSelectedRows. test=develop

* update. test=develop
上级 6524fa8d
......@@ -138,23 +138,11 @@ void SerializeSelectedRows(framework::Variable* var,
var_data->clear();
var_data->resize(rows->size() * sizeof(int64_t));
char* data_ptr = const_cast<char*>(var_data->data());
if (platform::is_cpu_place(tensor->place())) {
memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t));
} else {
#ifdef PADDLE_WITH_CUDA
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), data_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
&(*rows)[0], rows->size() * sizeof(int64_t), stream);
#endif
}
memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
for (auto& dim : framework::vectorize(tensor->dims())) {
var_msg->add_dims(dim);
}
// IO Buffer
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册