diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc
index 8ac8978120ad5930cd80272189ac0a83a77b2617..9949d80434c43ce846895c8d4c84221008a7fd8a 100644
--- a/paddle/fluid/memory/detail/system_allocator.cc
+++ b/paddle/fluid/memory/detail/system_allocator.cc
@@ -79,7 +79,18 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
   // if size is 0. We just make sure it does.
   if (size <= 0) return nullptr;
   void* p;
+  int prev_id;
+  cudaGetDevice(&prev_id);
+  if (prev_id != gpu_id_) {
+    cudaSetDevice(gpu_id_);
+  }
+
   cudaError_t result = cudaMalloc(&p, size);
+
+  if (prev_id != gpu_id_) {
+    cudaSetDevice(prev_id);
+  }
+
   if (result == cudaSuccess) {
     index = 0;
     gpu_alloc_size_ += size;
diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h
index e93c2c1e3231f7f42794dd78121072dbdb6abc41..c103d0864012d23d0390076840ee1a61b12ad048 100644
--- a/paddle/fluid/memory/detail/system_allocator.h
+++ b/paddle/fluid/memory/detail/system_allocator.h
@@ -43,6 +43,8 @@ class CPUAllocator : public SystemAllocator {
 #ifdef PADDLE_WITH_CUDA
 class GPUAllocator : public SystemAllocator {
  public:
+  explicit GPUAllocator(int gpu_id) : gpu_id_(gpu_id) {}
+
   virtual void* Alloc(size_t& index, size_t size);
   virtual void Free(void* p, size_t size, size_t index);
   virtual bool UseGpu() const;
@@ -50,6 +52,7 @@ class GPUAllocator : public SystemAllocator {
  private:
   size_t gpu_alloc_size_ = 0;
   size_t fallback_alloc_size_ = 0;
+  int gpu_id_;
 };
 #endif
diff --git a/paddle/fluid/memory/memory.cc b/paddle/fluid/memory/memory.cc
index d07f89439a1ec37682f79799d5569cad2ab75818..1985f1f4e68db1e62ee7cfd3649312581840d02c 100644
--- a/paddle/fluid/memory/memory.cc
+++ b/paddle/fluid/memory/memory.cc
@@ -69,7 +69,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
   }
   platform::SetDeviceId(gpu_id);
   if (!as[gpu_id]) {
-    as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator,
+    as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator(gpu_id),
                                     platform::GpuMinChunkSize(),
                                     platform::GpuMaxChunkSize());
    VLOG(10) << "\n\nNOTE: each GPU device use "
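
Note on the pattern in the Alloc hunk above: the patch binds each GPUAllocator to a fixed device (gpu_id_), and Alloc saves the caller's current device, switches to gpu_id_ for the cudaMalloc call, then restores the previous device so the caller's device context is unchanged. The same save/switch/restore idiom can be written as an RAII guard; the sketch below is illustrative only, and CUDADeviceGuard is a hypothetical name, not part of this patch.

#include <cuda_runtime.h>

// Hypothetical RAII guard capturing the save/switch/restore idiom that the
// Alloc hunk writes out by hand. Not part of this patch.
class CUDADeviceGuard {
 public:
  explicit CUDADeviceGuard(int device_id) {
    cudaGetDevice(&prev_id_);    // remember the caller's current device
    if (prev_id_ != device_id) {
      cudaSetDevice(device_id);  // switch to the target device
      switched_ = true;
    }
  }
  ~CUDADeviceGuard() {
    if (switched_) cudaSetDevice(prev_id_);  // restore the caller's device
  }

 private:
  int prev_id_ = 0;
  bool switched_ = false;
};

// With such a guard, the body of GPUAllocator::Alloc would reduce to:
//   CUDADeviceGuard guard(gpu_id_);
//   cudaError_t result = cudaMalloc(&p, size);

The guard has the advantage that the previous device is restored on every exit path, including early returns, whereas the hand-written version in the patch must pair each cudaSetDevice(gpu_id_) with an explicit restore.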