提交 170ac721 编写于 作者: K Kexin Zhao 提交者: Kexin Zhao

remove unnecessary tensor copy in save op

上级 8b1918f5
...@@ -46,19 +46,6 @@ class LoadOp : public framework::OperatorBase { ...@@ -46,19 +46,6 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx); DeserializeFromStream(fin, tensor, *dev_ctx);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
framework::LoDTensor cpu_tensor;
cpu_tensor.ShareDataWith(*tensor);
cpu_tensor.set_lod(tensor->lod());
// reset tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册