提交 b3115fb0 编写于 作者: L liaogang

Add SetDeviceId in memcpy

上级 0897d18a
......@@ -86,18 +86,22 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
}
template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace, void* dst,
platform::GPUPlace,
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}
template <>
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace, void* dst,
platform::CPUPlace,
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
......@@ -108,6 +112,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
const void* src, size_t num,
cudaStream_t stream) {
if (dst_place == src_place) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
} else {
platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册