From b3115fb01c007abea7e7ea7bf41363c5669e844a Mon Sep 17 00:00:00 2001
From: liaogang
Date: Thu, 20 Jul 2017 11:21:37 +0800
Subject: [PATCH] Add SetDeviceId in memcpy

---
 paddle/memory/memory.cc | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc
index 4056a54b4a3..78443cc35a4 100644
--- a/paddle/memory/memory.cc
+++ b/paddle/memory/memory.cc
@@ -86,18 +86,22 @@ size_t Used(platform::GPUPlace place) {
 }
 
 template <>
-void Copy(platform::CPUPlace, void* dst,
-          platform::GPUPlace,
+void Copy(platform::CPUPlace dst_place,
+          void* dst,
+          platform::GPUPlace src_place,
           const void* src, size_t num,
           cudaStream_t stream) {
+  platform::SetDeviceId(src_place.device);
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
 }
 
 template <>
-void Copy(platform::GPUPlace, void* dst,
-          platform::CPUPlace,
+void Copy(platform::GPUPlace dst_place,
+          void* dst,
+          platform::CPUPlace src_place,
           const void* src, size_t num,
           cudaStream_t stream) {
+  platform::SetDeviceId(dst_place.device);
   platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
 }
 
@@ -108,6 +112,7 @@ void Copy(platform::GPUPlace dst_place,
           const void* src, size_t num,
           cudaStream_t stream) {
   if (dst_place == src_place) {
+    platform::SetDeviceId(src_place.device);
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
   } else {
     platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
-- 
GitLab