diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index b3351f44dc35a4c040868ce96d44eea9ce83624f..c0d1934a703b66a8ab8a1eab0c1d0680d73b9e17 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -42,7 +42,7 @@ endif() cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) if (WITH_GPU OR WITH_ROCM) - set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator) + set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator device_context) if(CUDA_VERSION GREATER_EQUAL 10.2) list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator) endif() diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index c836593f3f409fcef5876c87112cbb119564e51c..a2d198aba032308d2f5233e71b832f6515b54f6b 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -31,6 +31,7 @@ #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" #include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device_context.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h" @@ -175,12 +176,12 @@ class AllocatorFacadePrivate { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) allow_free_idle_chunk_ = allow_free_idle_chunk; if (FLAGS_use_stream_safe_cuda_allocator) { + default_streams_ = + std::vector(platform::GetGPUDeviceCount(), nullptr); // TODO(Ruibiao): Support multi-stream allocator for other strategies - default_stream_ = nullptr; for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); ++dev_id) { - InitStreamSafeCUDAAllocator(platform::CUDAPlace(dev_id), - default_stream_); + InitStreamSafeCUDAAllocator(platform::CUDAPlace(dev_id), nullptr); } } else { for (int dev_id = 0; dev_id < platform::GetGPUDeviceCount(); @@ -285,15 +286,51 @@ class AllocatorFacadePrivate { return stream_it->second; } - gpuStream_t GetDefaultStream() { return default_stream_; } + const gpuStream_t& GetDefaultStream(const platform::CUDAPlace& place) { + int dev_id = place.GetDeviceId(); + gpuStream_t& default_stream = default_streams_[dev_id]; + if (UNLIKELY(default_stream == nullptr)) { + /* NOTE(Ruibiao): Here if we set default_stream by code " default_stream = + * platform::stream::get_current_stream(place.GetDeviceId())->raw_stream() + * ", then it will be fail to make target 'jit_kernel_benchmark', says a + * undefined reference to `paddle::platform::DeviceContextPool::Get( + * paddle::platform::Place const&)' in function + * `paddle::platform::stream::get_current_stream(int)'. However, target + * allocator_facade will not be affected. It seems a circular dependency + * problem between 'cuda_stream' and 'device_context' that causes this + * strange bug. + */ + platform::DeviceContextPool& pool = + platform::DeviceContextPool::Instance(); + default_stream = + static_cast(pool.Get(place))->stream(); + InitStreamSafeCUDAAllocator(place, default_stream); + } + return default_stream; + } - void RecordStream(Allocation* allocation, const gpuStream_t& stream) { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(allocation->place()), true, - platform::errors::InvalidArgument( - "Not allow to record stream for an allocation with place %s", - allocation->place())); - dynamic_cast(allocation)->RecordStream(stream); + void RecordStream(std::shared_ptr allocation, + const gpuStream_t& stream) { + StreamSafeCUDAAllocation* stream_safe_cuda_allocation = + dynamic_cast(allocation.get()); + PADDLE_ENFORCE_NOT_NULL(stream_safe_cuda_allocation, + platform::errors::InvalidArgument( + "Failed to dynamic cast %p from Allocation* to " + "StreamSafeCUDAAllocation*", + allocation.get())); + stream_safe_cuda_allocation->RecordStream(stream); + } + + const gpuStream_t& GetStream( + const std::shared_ptr& allocation) const { + const StreamSafeCUDAAllocation* stream_safe_cuda_allocation = + dynamic_cast(allocation.get()); + PADDLE_ENFORCE_NOT_NULL(stream_safe_cuda_allocation, + platform::errors::InvalidArgument( + "Failed to dynamic cast %p from Allocation* to " + "StreamSafeCUDAAllocation*", + allocation.get())); + return stream_safe_cuda_allocation->GetOwningStream(); } #ifdef PADDLE_WITH_CUDA @@ -705,7 +742,7 @@ class AllocatorFacadePrivate { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // a standalone CUDA allocator to support multi-stream GC in new executor CUDAAllocatorMap cuda_allocators_; - gpuStream_t default_stream_; + std::vector default_streams_; SpinLock cuda_allocators_lock_; #ifdef PADDLE_WITH_CUDA std::unordered_map> @@ -745,8 +782,9 @@ const std::shared_ptr& AllocatorFacade::GetAllocator( } #endif - return m_->GetAllocator(BOOST_GET_CONST(platform::CUDAPlace, place), - m_->GetDefaultStream()); + platform::CUDAPlace cuda_place = + BOOST_GET_CONST(platform::CUDAPlace, place); + return m_->GetAllocator(cuda_place, m_->GetDefaultStream(cuda_place)); } #endif @@ -769,8 +807,9 @@ AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, } #endif - return Alloc(BOOST_GET_CONST(platform::CUDAPlace, place), size, - m_->GetDefaultStream()); + platform::CUDAPlace cuda_place = + BOOST_GET_CONST(platform::CUDAPlace, place); + return Alloc(cuda_place, size, m_->GetDefaultStream(cuda_place)); } #endif @@ -789,8 +828,9 @@ uint64_t AllocatorFacade::Release(const platform::Place& place) { } #endif - return Release(BOOST_GET_CONST(platform::CUDAPlace, place), - m_->GetDefaultStream()); + platform::CUDAPlace cuda_place = + BOOST_GET_CONST(platform::CUDAPlace, place); + return Release(cuda_place, m_->GetDefaultStream(cuda_place)); } #endif return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1) @@ -804,9 +844,9 @@ std::shared_ptr AllocatorFacade::AllocShared( FLAGS_use_stream_safe_cuda_allocator, true, platform::errors::Unimplemented( "StreamSafeCUDAAllocator is disabled, you should not call this " - "multi-stream 'AllocaShared' function. " - "To enable it, you can enter 'export " - "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + "multi-stream 'AllocaShared' function. To enable it, you can enter" + "'export FLAGS_use_stream_safe_cuda_allocator=true' in the " + "terminal.")); #ifdef PADDLE_WITH_CUDA if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { @@ -824,9 +864,9 @@ AllocationPtr AllocatorFacade::Alloc(const platform::CUDAPlace& place, FLAGS_use_stream_safe_cuda_allocator, true, platform::errors::Unimplemented( "StreamSafeCUDAAllocator is disabled, you should not call this " - "multi-stream 'Alloca' function. " - "To enable it, you can enter 'export " - "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + "multi-stream 'Alloc' function. To enable it, you can enter" + "'export FLAGS_use_stream_safe_cuda_allocator=true' in the " + "terminal.")); #ifdef PADDLE_WITH_CUDA if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { @@ -836,7 +876,7 @@ AllocationPtr AllocatorFacade::Alloc(const platform::CUDAPlace& place, #endif if (LIKELY(size > 0 && FLAGS_use_system_allocator == false)) { - return m_->GetAllocator(place, stream, /* creat_if_not_found = */ true) + return m_->GetAllocator(place, stream, /* create_if_not_found = */ true) ->Allocate(size); } else { return m_->GetAllocator(place, size)->Allocate(size); @@ -849,9 +889,9 @@ uint64_t AllocatorFacade::Release(const platform::CUDAPlace& place, FLAGS_use_stream_safe_cuda_allocator, true, platform::errors::Unimplemented( "StreamSafeCUDAAllocator is disabled, you should not call this " - "multi-stream 'Release' function. " - "To enable it, you can enter 'export " - "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + "multi-stream 'Release' function. To enable it, you can enter" + "'export FLAGS_use_stream_safe_cuda_allocator=true' in the " + "terminal.")); #ifdef PADDLE_WITH_CUDA if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { @@ -863,15 +903,15 @@ uint64_t AllocatorFacade::Release(const platform::CUDAPlace& place, return m_->GetAllocator(place, stream)->Release(place); } -void AllocatorFacade::RecordStream(Allocation* allocation, +void AllocatorFacade::RecordStream(std::shared_ptr allocation, const gpuStream_t& stream) { PADDLE_ENFORCE_EQ( FLAGS_use_stream_safe_cuda_allocator, true, platform::errors::Unimplemented( "StreamSafeCUDAAllocator is disabled, you should not call this " - "'RecordStream' function. " - "To enable it, you can enter 'export " - "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + "'RecordStream' function. To enable it, you can enter" + "'export FLAGS_use_stream_safe_cuda_allocator=true' in the " + "terminal.")); #ifdef PADDLE_WITH_CUDA if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { @@ -883,6 +923,26 @@ void AllocatorFacade::RecordStream(Allocation* allocation, m_->RecordStream(allocation, stream); } +const gpuStream_t& AllocatorFacade::GetStream( + const std::shared_ptr& allocation) const { + PADDLE_ENFORCE_EQ( + FLAGS_use_stream_safe_cuda_allocator, true, + platform::errors::Unimplemented( + "StreamSafeCUDAAllocator is disabled, you should not call this " + "'GetStream' function. To enable it, you can enter" + "'export FLAGS_use_stream_safe_cuda_allocator=true' in the " + "terminal.")); + +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { + PADDLE_THROW(platform::errors::Unavailable( + "Not allow to use StreamSafeCUDAAllocator with CUDAGraphAllocator")); + } +#endif + + return m_->GetStream(allocation); +} + #ifdef PADDLE_WITH_CUDA void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { return m_->PrepareMemoryPoolForCUDAGraph(id); diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 0d9f1043d9e86ab4fca5d06ed6768d90812668b2..4c4f805a0c61968d197720c46411f1021d323041 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -64,7 +64,10 @@ class AllocatorFacade { AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size, const gpuStream_t& stream); uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream); - void RecordStream(Allocation* allocation, const gpuStream_t& stream); + void RecordStream(std::shared_ptr allocation, + const gpuStream_t& stream); + const gpuStream_t& GetStream( + const std::shared_ptr& allocation) const; #ifdef PADDLE_WITH_CUDA void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id); void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id); diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc index 86f3135ee4d147028eb12057f9be7ae451f6860e..7b6b61d7a60ca8ab68388820811f7c684f65cc95 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" -#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace memory { @@ -24,36 +23,92 @@ StreamSafeCUDAAllocation::StreamSafeCUDAAllocation( : Allocation(underlying_allocation->ptr(), underlying_allocation->size(), underlying_allocation->place()), underlying_allocation_(std::move(underlying_allocation)), - owning_stream_(owning_stream), - recorded_streams_(std::make_shared>()) {} + owning_stream_(std::move(owning_stream)) {} -void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) { - VLOG(8) << "Record stream " << stream << " to " << ptr(); +void StreamSafeCUDAAllocation::RecordStream(const gpuStream_t& stream) { + VLOG(8) << "Try record stream " << stream << " for address " << ptr(); if (stream == owning_stream_) { + VLOG(9) << "Record the same stream of " << stream; return; } - std::lock_guard lock_guard(spin_lock_); - recorded_streams_->insert(stream); + + std::lock_guard lock_guard(outstanding_event_map_lock_); + gpuEvent_t record_event; + auto it = outstanding_event_map_.find(stream); + if (it == outstanding_event_map_.end()) { + gpuEvent_t new_event; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventCreateWithFlags(&new_event, cudaEventDisableTiming)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + hipEventCreateWithFlags(&new_event, hipEventDisableTiming)); +#endif + outstanding_event_map_[stream] = new_event; + record_event = new_event; + VLOG(9) << "Create a new event " << new_event; + } else { + record_event = it->second; + VLOG(9) << "Reuse event " << record_event; + } + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(record_event, stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(record_event, stream)); +#endif + VLOG(8) << "Record event " << record_event << " to stream " << stream; +} + +bool StreamSafeCUDAAllocation::CanBeFreed() { + // NOTE(Ruibiao): This function will not execute concurrently, + // so outstanding_event_lock_ is not required here + for (auto it = outstanding_event_map_.begin(); + it != outstanding_event_map_.end(); ++it) { + gpuEvent_t& event = it->second; +#ifdef PADDLE_WITH_CUDA + gpuError_t err = cudaEventQuery(event); + if (err == cudaErrorNotReady) { + VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; + // Erase the completded event before "it" + outstanding_event_map_.erase(outstanding_event_map_.begin(), it); + return false; + } + PADDLE_ENFORCE_GPU_SUCCESS(err); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event)); +#else + gpuError_t err = hipEventQuery(event); + if (err == hipErrorNotReady) { + VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; + // Erase the completded event before "it" + outstanding_event_map_.erase(outstanding_event_map_.begin(), it); + return false; + } + PADDLE_ENFORCE_GPU_SUCCESS(err); + PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(event)); +#endif + VLOG(8) << "Destroy event " << event; + } + return true; } -std::shared_ptr> -StreamSafeCUDAAllocation::GetRecordedStreams() { - return recorded_streams_; +const gpuStream_t& StreamSafeCUDAAllocation::GetOwningStream() const { + return owning_stream_; } StreamSafeCUDAAllocator::StreamSafeCUDAAllocator( - const std::shared_ptr& underlying_allocator, - const platform::CUDAPlace& place, const gpuStream_t default_stream) - : underlying_allocator_(underlying_allocator), - place_(place), - default_stream_(default_stream) { - std::lock_guard lock_guard(allocators_map_lock_); - allocators_map_[place].emplace_back(this); + std::shared_ptr underlying_allocator, platform::CUDAPlace place, + gpuStream_t default_stream) + : underlying_allocator_(std::move(underlying_allocator)), + place_(std::move(place)), + default_stream_(std::move(default_stream)) { + std::lock_guard lock_guard(allocator_map_lock_); + allocator_map_[place].emplace_back(this); } StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() { - std::lock_guard lock_guard(allocators_map_lock_); - std::vector& allocators = allocators_map_[place_]; + std::lock_guard lock_guard(allocator_map_lock_); + std::vector& allocators = allocator_map_[place_]; allocators.erase(std::remove(allocators.begin(), allocators.end(), this), allocators.end()); } @@ -61,147 +116,80 @@ StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() { bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; } Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) { - ProcessEventsAndFree(); + ProcessUnfreedAllocations(); AllocationPtr underlying_allocation; try { underlying_allocation = underlying_allocator_->Allocate(size); } catch (BadAlloc&) { - VLOG(9) << "Allocation failed when allocating " << size << " bytes"; - uint64_t release_size = ReleaseImpl(place_); - VLOG(9) << "Release " << release_size << " bytes memory from all streams"; + VLOG(4) << "Allocation failed when allocating " << size << " bytes"; + ReleaseImpl(place_); try { underlying_allocation = underlying_allocator_->Allocate(size); } catch (...) { - VLOG(9) << "Still allocation failed after release memory"; + VLOG(3) + << "Still allocation failed after release memory from all streams"; throw; } } catch (...) { throw; } - StreamSafeCUDAAllocation* allocation = new StreamSafeCUDAAllocation( std::move(underlying_allocation), default_stream_); + VLOG(8) << "Allocate " << allocation->size() << " bytes at address " + << allocation->ptr(); return allocation; } void StreamSafeCUDAAllocator::FreeImpl(Allocation* allocation) { - if (dynamic_cast(allocation) - ->GetRecordedStreams() - ->empty()) { - delete allocation; + StreamSafeCUDAAllocation* stream_safe_cuda_allocation = + dynamic_cast(allocation); + PADDLE_ENFORCE_NOT_NULL(stream_safe_cuda_allocation, + platform::errors::InvalidArgument( + "Failed to dynamic cast %p from Allocation* to " + "StreamSafeCUDAAllocation*", + allocation)); + VLOG(8) << "Try free allocation " << stream_safe_cuda_allocation->ptr(); + if (stream_safe_cuda_allocation->CanBeFreed()) { + delete stream_safe_cuda_allocation; } else { - std::lock_guard lock_guard(outstanding_events_map_lock_); - FreeStreamSafeCUDAAllocation(allocation); + std::lock_guard lock_guard(unfreed_allocation_lock_); + unfreed_allocations_.emplace_back(stream_safe_cuda_allocation); } } uint64_t StreamSafeCUDAAllocator::ReleaseImpl(const platform::Place& place) { - std::lock_guard lock_guard(allocators_map_lock_); + std::lock_guard lock_guard(allocator_map_lock_); std::vector& allocators = - allocators_map_[BOOST_GET_CONST(platform::CUDAPlace, place)]; - uint64_t release_size = 0; + allocator_map_[BOOST_GET_CONST(platform::CUDAPlace, place)]; + uint64_t released_size = 0; for (StreamSafeCUDAAllocator* allocator : allocators) { - release_size += allocator->ProcessEventsAndFreeWithRelease(); - } - VLOG(8) << "Release " << release_size - << " bytes memory from all stream for place " << place; - return release_size; -} - -void StreamSafeCUDAAllocator::CreateEventForAllRecordedStream( - std::set* recorded_streams, - std::deque* outstanding_events) { - for (gpuStream_t stream : *recorded_streams) { - gpuEvent_t event; -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, stream)); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - hipEventCreateWithFlags(&event, hipEventDisableTiming)); - PADDLE_ENFORCE_GPU_SUCCESS(hipEventRecord(event, stream)); -#endif - outstanding_events->emplace_back(event); - VLOG(9) << "Record event " << event << " in stream " << stream; + released_size += allocator->ProcessUnfreedAllocationsWithRelease(); } - recorded_streams->clear(); + VLOG(8) << "Release " << released_size << " bytes memory from all streams"; + return released_size; } -void StreamSafeCUDAAllocator::FreeStreamSafeCUDAAllocation( - Allocation* allocation) { - std::deque& outstanding_events = - outstanding_events_map_[allocation]; - CreateEventForAllRecordedStream( - dynamic_cast(allocation) - ->GetRecordedStreams() - .get(), - &outstanding_events); - if (!outstanding_events.empty()) { - VLOG(8) << allocation->ptr() << " is not ready to free"; - return; - } - - VLOG(8) << "Free " << allocation->ptr(); - outstanding_events_map_.erase(allocation); - delete allocation; -} - -void StreamSafeCUDAAllocator::ProcessEventsAndFree() { - std::lock_guard lock_guard(outstanding_events_map_lock_); - for (auto map_it = outstanding_events_map_.begin(); - map_it != outstanding_events_map_.end();) { - std::deque& outstanding_events = map_it->second; - VLOG(10) << "Check " << outstanding_events.size() - << " outstanding events for " << map_it->first->ptr(); - auto deque_it = outstanding_events.begin(); - while (deque_it != outstanding_events.end()) { -#ifdef PADDLE_WITH_CUDA - gpuError_t err = cudaEventQuery(*deque_it); - if (err == cudaErrorNotReady) { - VLOG(10) << "Event " << *deque_it << " for " << map_it->first->ptr() - << " is not completed"; - outstanding_events.erase(outstanding_events.begin(), deque_it); - break; - } - PADDLE_ENFORCE_GPU_SUCCESS(err); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(*deque_it)); -#else - gpuError_t err = hipEventQuery(*deque_it); - if (err == hipErrorNotReady) { - VLOG(10) << "Event " << *deque_it << " for " << map_it->first->ptr() - << " is not completed"; - // Erase the completded event before "deque_it" - outstanding_events.erase(outstanding_events.begin(), deque_it); - break; - } - PADDLE_ENFORCE_GPU_SUCCESS(err); - PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(*deque_it)); -#endif - ++deque_it; - } - - if (deque_it == outstanding_events.end()) { - outstanding_events.clear(); - Allocation* allocation = map_it->first; - // "map_it" may be invalid after calling FreeStreamSafeCUDAAllocation - auto next_it = ++map_it; - FreeStreamSafeCUDAAllocation(allocation); - map_it = next_it; +void StreamSafeCUDAAllocator::ProcessUnfreedAllocations() { + std::lock_guard lock_guard(unfreed_allocation_lock_); + for (auto it = unfreed_allocations_.begin(); + it != unfreed_allocations_.end();) { + if ((*it)->CanBeFreed()) { + delete *it; + it = unfreed_allocations_.erase(it); } else { - ++map_it; + ++it; } } } -uint64_t StreamSafeCUDAAllocator::ProcessEventsAndFreeWithRelease() { - ProcessEventsAndFree(); +uint64_t StreamSafeCUDAAllocator::ProcessUnfreedAllocationsWithRelease() { + ProcessUnfreedAllocations(); return underlying_allocator_->Release(place_); } std::map> - StreamSafeCUDAAllocator::allocators_map_; -SpinLock StreamSafeCUDAAllocator::allocators_map_lock_; + StreamSafeCUDAAllocator::allocator_map_; +SpinLock StreamSafeCUDAAllocator::allocator_map_lock_; } // namespace allocation } // namespace memory diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h index a516558228be63ca5e9f5b874dbd2c25ce39edb5..d84994f58a9c40e7bc2f4adc64a01ca667104382 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h @@ -13,21 +13,21 @@ // limitations under the License. #pragma once -#ifdef PADDLE_WITH_CUDA -#include -#else -#include -#endif #include +#include #include -#include #include -#include #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/spin_lock.h" #include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include +#else +#include +#endif + namespace paddle { namespace memory { namespace allocation { @@ -36,21 +36,23 @@ class StreamSafeCUDAAllocation : public Allocation { public: StreamSafeCUDAAllocation(AllocationPtr underlying_allocation, gpuStream_t owning_stream); - void RecordStream(gpuStream_t stream); - std::shared_ptr> GetRecordedStreams(); + void RecordStream(const gpuStream_t &stream); + bool CanBeFreed(); + + const gpuStream_t &GetOwningStream() const; private: AllocationPtr underlying_allocation_; + std::map outstanding_event_map_; gpuStream_t owning_stream_; - std::shared_ptr> recorded_streams_; - SpinLock spin_lock_; + SpinLock outstanding_event_map_lock_; }; class StreamSafeCUDAAllocator : public Allocator { public: - StreamSafeCUDAAllocator( - const std::shared_ptr &underlying_allocator, - const platform::CUDAPlace &place, const gpuStream_t default_stream); + StreamSafeCUDAAllocator(std::shared_ptr underlying_allocator, + platform::CUDAPlace place, + gpuStream_t default_stream); ~StreamSafeCUDAAllocator(); bool IsAllocThreadSafe() const override; @@ -60,22 +62,18 @@ class StreamSafeCUDAAllocator : public Allocator { uint64_t ReleaseImpl(const platform::Place &place) override; private: - void CreateEventForAllRecordedStream( - std::set *recorded_streams, - std::deque *outstanding_events); - void FreeStreamSafeCUDAAllocation(Allocation *allocation); - void ProcessEventsAndFree(); - uint64_t ProcessEventsAndFreeWithRelease(); + void ProcessUnfreedAllocations(); + uint64_t ProcessUnfreedAllocationsWithRelease(); static std::map> - allocators_map_; - static SpinLock allocators_map_lock_; + allocator_map_; + static SpinLock allocator_map_lock_; std::shared_ptr underlying_allocator_; platform::CUDAPlace place_; gpuStream_t default_stream_; - std::map> outstanding_events_map_; - SpinLock outstanding_events_map_lock_; + std::list unfreed_allocations_; + SpinLock unfreed_allocation_lock_; }; } // namespace allocation diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index 4921b87ccd99e9735fe5bba97e911d4cebd10841..5ec96c39bb6045255d7aff294fd02b3ecc2faac7 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -50,10 +50,16 @@ uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream) { return allocation::AllocatorFacade::Instance().Release(place, stream); } -void RecordStream(Allocation* allocation, const gpuStream_t& stream) { +void RecordStream(std::shared_ptr allocation, + const gpuStream_t& stream) { return allocation::AllocatorFacade::Instance().RecordStream(allocation, stream); } + +const gpuStream_t& GetStream(const std::shared_ptr& allocation) { + return allocation::AllocatorFacade::Instance().GetStream(allocation); +} + #endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h index 2aa9fbe6ada8fe7124a2e7d6033f85df59feff9d..7ca15c5dfc1279f21221aa35c8cb445d44e2ca7f 100644 --- a/paddle/fluid/memory/malloc.h +++ b/paddle/fluid/memory/malloc.h @@ -51,7 +51,10 @@ extern AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size, extern uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream); -void RecordStream(Allocation* allocation, const gpuStream_t& stream); +void RecordStream(std::shared_ptr allocation, + const gpuStream_t& stream); + +const gpuStream_t& GetStream(const std::shared_ptr& allocation); #endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu index 134c368d4340e379feb05172d0963049ec65e4d5..286dcdba8f22fecc55ce487857fc1b434b87b546 100644 --- a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu +++ b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu @@ -29,14 +29,16 @@ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace memory { __global__ void add_kernel(int *x, int n) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int i = tid; i < n; i += blockDim.x * gridDim.x) { - atomicAdd(x + i, tid); + int thread_num = gridDim.x * blockDim.x; + int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + for (int i = thread_id; i < n; i += thread_num) { + atomicAdd(x + i, thread_id); } } @@ -54,26 +56,21 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { place_ = platform::CUDAPlace(); stream_num_ = 64; grid_num_ = 1; - block_num_ = 64; - data_num_ = 64; - default_stream = nullptr; + block_num_ = 32; + data_num_ = 131072; + workspace_size_ = data_num_ * sizeof(int); - streams_.reserve(stream_num_); - streams_.emplace_back(default_stream); - for (size_t i = 1; i < stream_num_; ++i) { + // alloc workspace for each stream + for (size_t i = 0; i < stream_num_; ++i) { gpuStream_t stream; #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&stream)); #else PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream)); #endif - streams_.emplace_back(stream); - } - for (size_t i = 0; i < stream_num_; ++i) { - size_t allocation_size = data_num_ * sizeof(int); std::shared_ptr allocation = - AllocShared(place_, allocation_size, streams_[i]); + AllocShared(place_, workspace_size_, stream); #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS( cudaMemset(allocation->ptr(), 0, allocation->size())); @@ -81,25 +78,45 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { PADDLE_ENFORCE_GPU_SUCCESS( hipMemset(allocation->ptr(), 0, allocation->size())); #endif - allocations_.emplace_back(allocation); + + streams_.emplace_back(stream); + workspaces_.emplace_back(allocation); } + + result_ = AllocShared(place_, stream_num_ * workspace_size_); } void SingleStreamRun(size_t idx) { + // for all stream i, + // stream idx lauch a kernel to add (j % thread_num) to workspaces_[i][j] for (size_t i = 0; i < stream_num_; ++i) { - int *x = reinterpret_cast(allocations_[i]->ptr()); + int *x = reinterpret_cast(workspaces_[i]->ptr()); add_kernel<<>>(x, data_num_); - if (i != idx) { - RecordStream(allocations_[i].get(), streams_[idx]); - } + RecordStream(workspaces_[i], streams_[idx]); + } + } + + void CopyResultAsync() { + for (size_t i = 0; i < stream_num_; ++i) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync( + reinterpret_cast(result_->ptr()) + i * data_num_, + workspaces_[i]->ptr(), workspace_size_, cudaMemcpyDeviceToDevice)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpyAsync( + reinterpret_cast(result_->ptr()) + i * data_num_, + workspaces_[i]->ptr(), workspace_size_, hipMemcpyDeviceToDevice)); +#endif } } void MultiStreamRun() { - for (int i = 0; i < stream_num_; ++i) { + for (size_t i = 0; i < stream_num_; ++i) { SingleStreamRun(i); } - allocations_.clear(); // fast_gc + CopyResultAsync(); + workspaces_.clear(); // fast_gc + cudaDeviceSynchronize(); } void MultiThreadMUltiStreamRun() { @@ -111,28 +128,30 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { for (size_t i = 0; i < stream_num_; ++i) { threads[i].join(); } - allocations_.clear(); // fast_gc + CopyResultAsync(); + workspaces_.clear(); // fast_gc + cudaDeviceSynchronize(); } void CheckResult() { - auto host_x = std::unique_ptr(new int[data_num_]); - size_t thread_num = grid_num_ * block_num_; - for (int i = 0; i < stream_num_; ++i) { -// tricky code, the allocations are still accessible even though -// allocations_.clear() has been called + auto result_host = std::unique_ptr(new int[result_->size()]); #ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpy(host_x.get(), allocations_[i]->ptr(), - data_num_ * sizeof(int), cudaMemcpyDeviceToHost)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(result_host.get(), result_->ptr(), + result_->size(), + cudaMemcpyDeviceToHost)); #else - PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(host_x.get(), allocations_[i]->ptr(), - data_num_ * sizeof(int), - hipMemcpyDeviceToHost)); + PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(result_host.get(), result_->ptr(), + result_->size(), + hipMemcpyDeviceToHost)); #endif - for (int j = 0; j < data_num_; ++j) { - EXPECT_TRUE(host_x[j] == (j % thread_num) * stream_num_); + size_t thread_num = grid_num_ * block_num_; + for (size_t i = 0; i < stream_num_; ++i) { + for (size_t j = 0; j < data_num_; ++j) { + EXPECT_TRUE(result_host[i * stream_num_ + j] == + (j % thread_num) * stream_num_); } } + result_.reset(); } void TearDown() override { @@ -160,10 +179,11 @@ class StreamSafeCUDAAllocTest : public ::testing::Test { size_t grid_num_; size_t block_num_; size_t data_num_; + size_t workspace_size_; platform::CUDAPlace place_; - gpuStream_t default_stream; std::vector streams_; - std::vector> allocations_; + std::vector> workspaces_; + std::shared_ptr result_; }; TEST_F(StreamSafeCUDAAllocTest, CUDAMutilStreamTest) { @@ -187,7 +207,10 @@ TEST(StreamSafeCUDAAllocInterfaceTest, AllocInterfaceTest) { void *address = allocation_implicit_stream->ptr(); allocation_implicit_stream.reset(); - gpuStream_t default_stream = nullptr; + gpuStream_t default_stream = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(place)) + ->stream(); allocation::AllocationPtr allocation_unique = Alloc(place, alloc_size, default_stream); EXPECT_GE(allocation_unique->size(), alloc_size); @@ -220,6 +243,41 @@ TEST(StreamSafeCUDAAllocInterfaceTest, GetAllocatorInterfaceTest) { CheckMemLeak(place); } +TEST(StreamSafeCUDAAllocInterfaceTest, GetStreamInterfaceTest) { + platform::CUDAPlace place = platform::CUDAPlace(); + size_t alloc_size = 256; + + gpuStream_t default_stream = + dynamic_cast( + paddle::platform::DeviceContextPool::Instance().Get(place)) + ->stream(); + std::shared_ptr allocation_implicit_stream = + AllocShared(place, alloc_size); + EXPECT_EQ(GetStream(allocation_implicit_stream), default_stream); + + gpuStream_t new_stream; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreate(&new_stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&new_stream)); +#endif + + std::shared_ptr allocation_new_stream = + AllocShared(place, alloc_size, new_stream); + EXPECT_EQ(GetStream(allocation_new_stream), new_stream); + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(new_stream)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(new_stream)); +#endif + + allocation_implicit_stream.reset(); + allocation_new_stream.reset(); + Release(place); + CheckMemLeak(place); +} + #ifdef PADDLE_WITH_CUDA TEST(StreamSafeCUDAAllocInterfaceTest, CUDAGraphExceptionTest) { platform::CUDAPlace place = platform::CUDAPlace(); @@ -237,8 +295,9 @@ TEST(StreamSafeCUDAAllocInterfaceTest, CUDAGraphExceptionTest) { EXPECT_THROW(Alloc(place, alloc_size, nullptr), paddle::platform::EnforceNotMet); EXPECT_THROW(Release(place, nullptr), paddle::platform::EnforceNotMet); - EXPECT_THROW(RecordStream(allocation.get(), nullptr), + EXPECT_THROW(RecordStream(allocation, nullptr), paddle::platform::EnforceNotMet); + EXPECT_THROW(GetStream(allocation), paddle::platform::EnforceNotMet); platform::EndCUDAGraphCapture(); allocation.reset(); @@ -258,7 +317,8 @@ TEST(StreamSafeCUDAAllocRetryTest, RetryTest) { PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreate(&stream2)); #endif size_t available_size = platform::GpuAvailableMemToAlloc(); - // alloc_size < available_size < 2 * alloc_size + // alloc_size < available_size < 2 * alloc_size, + // so the second alloc will fail and retry size_t alloc_size = available_size / 4 * 3; std::shared_ptr allocation1 =