From 32633c8e8568006cf51d83356a878d9f12d73e0a Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Thu, 15 Dec 2022 10:33:26 +0800 Subject: [PATCH] SetDeviceId in StreamSafeCUDAAllocation (#49080) --- .../memory/allocation/stream_safe_cuda_allocator.cc | 9 +++++++++ .../fluid/memory/allocation/stream_safe_cuda_allocator.h | 1 + 2 files changed, 10 insertions(+) diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc index b62ba99df7..1967dd8502 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/phi/backends/gpu/gpu_info.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h" @@ -43,6 +44,9 @@ void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) { return; } + std::call_once(once_flag_, + [this] { phi::backends::gpu::SetDeviceId(place_.device); }); + std::lock_guard lock_guard(outstanding_event_map_lock_); #ifdef PADDLE_WITH_CUDA if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) { @@ -63,6 +67,9 @@ bool StreamSafeCUDAAllocation::CanBeFreed() { } #endif + std::call_once(once_flag_, + [this] { phi::backends::gpu::SetDeviceId(place_.device); }); + RecordGraphCapturingStreams(); for (auto it = outstanding_event_map_.begin(); @@ -259,6 +266,8 @@ uint64_t StreamSafeCUDAAllocator::ProcessUnfreedAllocationsAndRelease() { return underlying_allocator_->Release(place_); } +std::once_flag StreamSafeCUDAAllocation::once_flag_; + std::map> StreamSafeCUDAAllocator::allocator_map_; SpinLock StreamSafeCUDAAllocator::allocator_map_lock_; diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h index a6be0cadba..5f9b620810 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h @@ -45,6 +45,7 @@ class StreamSafeCUDAAllocation : public Allocation { gpuStream_t GetOwningStream() const; private: + static std::once_flag once_flag_; void RecordGraphCapturingStreams(); void RecordStreamWithNoGraphCapturing(gpuStream_t stream); DecoratedAllocationPtr underlying_allocation_; -- GitLab