提交 fa613206 编写于 作者: C chengduoZH

update

上级 4bfadcd1
...@@ -37,11 +37,17 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, ...@@ -37,11 +37,17 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
<< " dst_place: " << dst_place; << " dst_place: " << dst_place;
auto* dev_ctx = GetDeviceContext(in.place(), dst_place); auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
// the enforced checkings have been done in GetDeviceContext, so the
// `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
// slow, especially when the number of elements is little, for example,
// the elements of learning rate are one and it's CPU side.
// One solution is to use a CUDA kernel to complete the copy operation when
// the transforming is from CPU to GPU and the number of elements is little.
// But the embarrassment is that this solution this solution makes training
// slower.
TensorCopy(in, dst_place, *dev_ctx, out); TensorCopy(in, dst_place, *dev_ctx, out);
if (in.place().which() != dst_place.which()) {
dev_ctx->Wait(); dev_ctx->Wait();
}
} }
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册