提交 ed7727e8 编写于 作者: Y Yu Yang

Fix bug in system allocator

上级 95a0d7c7
...@@ -79,7 +79,18 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) { ...@@ -79,7 +79,18 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
// if size is 0. We just make sure it does. // if size is 0. We just make sure it does.
if (size <= 0) return nullptr; if (size <= 0) return nullptr;
void* p; void* p;
int prev_id;
cudaGetDevice(&prev_id);
if (prev_id != gpu_id_) {
cudaSetDevice(gpu_id_);
}
cudaError_t result = cudaMalloc(&p, size); cudaError_t result = cudaMalloc(&p, size);
if (prev_id != gpu_id_) {
cudaSetDevice(prev_id);
}
if (result == cudaSuccess) { if (result == cudaSuccess) {
index = 0; index = 0;
gpu_alloc_size_ += size; gpu_alloc_size_ += size;
......
...@@ -43,6 +43,8 @@ class CPUAllocator : public SystemAllocator { ...@@ -43,6 +43,8 @@ class CPUAllocator : public SystemAllocator {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
class GPUAllocator : public SystemAllocator { class GPUAllocator : public SystemAllocator {
public: public:
explicit GPUAllocator(int gpu_id) : gpu_id_(gpu_id) {}
virtual void* Alloc(size_t& index, size_t size); virtual void* Alloc(size_t& index, size_t size);
virtual void Free(void* p, size_t size, size_t index); virtual void Free(void* p, size_t size, size_t index);
virtual bool UseGpu() const; virtual bool UseGpu() const;
...@@ -50,6 +52,7 @@ class GPUAllocator : public SystemAllocator { ...@@ -50,6 +52,7 @@ class GPUAllocator : public SystemAllocator {
private: private:
size_t gpu_alloc_size_ = 0; size_t gpu_alloc_size_ = 0;
size_t fallback_alloc_size_ = 0; size_t fallback_alloc_size_ = 0;
int gpu_id_;
}; };
#endif #endif
......
...@@ -69,7 +69,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { ...@@ -69,7 +69,7 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
} }
platform::SetDeviceId(gpu_id); platform::SetDeviceId(gpu_id);
if (!as[gpu_id]) { if (!as[gpu_id]) {
as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator, as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator(gpu_id),
platform::GpuMinChunkSize(), platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize()); platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use " VLOG(10) << "\n\nNOTE: each GPU device use "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册