From b9c464c3179523df959c637f92e138935d70b278 Mon Sep 17 00:00:00 2001 From: From00 Date: Thu, 25 Nov 2021 15:49:11 +0800 Subject: [PATCH] Support multi-stream allocation for CUDA place (#37290) * Support multi-stream allocation for CUDA place * Do not notify the retrying from other streams when free CUDA allocation * Fix compile error for CPU * Fix compile error for HIP * Release memory for StreamSafeCUDAAllocaRetry in malloc_test * Add FLAGS_use_stream_safe_cuda_allocator * Fix CI error for 'set_tests_properties' * Invalidate stream safe CUDA allocator for naive_best_fit and thread_local strategy * Performance improvement: insert allocation pair to outstanding_events_map when free but not alloc; replace recursive_mutex with SpinLock * FLAGS priority changes: FLAGS_use_system_allocator > FLAGS_use_stream_safe_cuda_allocator * Performance improvement: directly delete allocation when the recorded_streams is empty in FreeImpl of StreamSafeCUDAAllocator * Add UT for alloc interface * Changes multi-stream interface; move retry code from AllocatorFacadePrivate to StreamSafeCUDAAllocator --- paddle/fluid/memory/CMakeLists.txt | 10 + paddle/fluid/memory/allocation/CMakeLists.txt | 10 +- .../memory/allocation/allocator_facade.cc | 514 ++++++++++++++---- .../memory/allocation/allocator_facade.h | 17 +- .../allocation/stream_safe_cuda_allocator.cc | 206 +++++++ .../allocation/stream_safe_cuda_allocator.h | 83 +++ paddle/fluid/memory/malloc.cc | 28 +- paddle/fluid/memory/malloc.h | 13 + .../memory/stream_safe_cuda_alloc_test.cu | 229 ++++++++ 9 files changed, 985 insertions(+), 125 deletions(-) create mode 100644 paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc create mode 100644 paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h create mode 100644 paddle/fluid/memory/stream_safe_cuda_alloc_test.cu diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 75b1bffca3..69134e1c76 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -17,6 +17,16 @@ if (WITH_GPU) nv_test(malloc_test SRCS malloc_test.cu DEPS device_context malloc) + nv_test(stream_safe_cuda_alloc_test + SRCS stream_safe_cuda_alloc_test.cu + DEPS malloc) + + if(WITH_TESTING AND TEST stream_safe_cuda_alloc_test) + set_tests_properties(stream_safe_cuda_alloc_test PROPERTIES + ENVIRONMENT "FLAGS_use_system_allocator=false" + ENVIRONMENT "FLAGS_enable_stream_safe_cuda_allocator=true" + ENVIRONMENT "FLAGS_allocator_strategy=auto_growth") + endif() endif() if (WITH_ROCM) diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 58979d6c3e..4d44c533b7 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -15,8 +15,10 @@ endif() if (WITH_GPU) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) - nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) + nv_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator) + nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) + cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator) if(CUDA_VERSION GREATER_EQUAL 10.2) nv_library(cuda_virtual_mem_allocator SRCS cuda_virtual_mem_allocator.cc DEPS dynload_cuda) @@ -25,8 +27,10 @@ endif() if (WITH_ROCM) hip_library(cuda_allocator SRCS 
cuda_allocator.cc DEPS allocator cuda_device_guard) - hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) + hip_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator) + hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) + cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator) endif() @@ -38,7 +42,7 @@ endif() cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) if (WITH_GPU OR WITH_ROCM) - set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator) + set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator) if(CUDA_VERSION GREATER_EQUAL 10.2) list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator) endif() diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index ca7f5655f0..8fcafced8a 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -15,36 +15,45 @@ #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "gflags/gflags.h" +#include "paddle/fluid/memory/allocation/aligned_allocator.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h" #include "paddle/fluid/memory/allocation/cpu_allocator.h" #include "paddle/fluid/memory/allocation/naive_best_fit_allocator.h" -#ifdef PADDLE_WITH_ASCEND_CL -#include "paddle/fluid/memory/allocation/npu_pinned_allocator.h" -#endif -#include "paddle/fluid/memory/allocation/aligned_allocator.h" #include "paddle/fluid/memory/allocation/retry_allocator.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/npu_info.h" #include "paddle/fluid/platform/place.h" + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_allocator.h" #include "paddle/fluid/memory/allocation/pinned_allocator.h" +#include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" #include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/gpu_info.h" + +#ifdef PADDLE_WITH_CUDA +#include +#include "paddle/fluid/platform/cuda_graph.h" +#else +#include #endif + #if CUDA_VERSION >= 10020 #include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" #include "paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h" #include "paddle/fluid/platform/dynload/cuda_driver.h" #endif -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/cuda_graph.h" #endif + #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/xpu_info.h" #endif -#include "paddle/fluid/platform/npu_info.h" + +#ifdef PADDLE_WITH_ASCEND_CL +#include "paddle/fluid/memory/allocation/npu_pinned_allocator.h" +#endif PADDLE_DEFINE_EXPORTED_int64( gpu_allocator_retry_time, 10000, @@ -59,6 +68,12 @@ PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, "Use VirtualMemoryAutoGrowthBestFitAllocator."); +// NOTE(Ruibiao): This FLAGS is just to be compatibled with +// the old single-stream CUDA allocator. It will be removed +// after StreamSafeCudaAllocator has been fully tested. 
+PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, true, + "Enable StreamSafeCUDAAllocator"); + DECLARE_string(allocator_strategy); namespace paddle { @@ -114,23 +129,34 @@ class AllocatorFacadePrivate { public: using AllocatorMap = std::map>; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + using CUDAAllocatorMap = + std::map>>; +#endif + explicit AllocatorFacadePrivate(bool allow_free_idle_chunk = true) { strategy_ = GetAllocatorStrategy(); switch (strategy_) { case AllocatorStrategy::kNaiveBestFit: { InitNaiveBestFitCPUAllocator(); -#ifdef PADDLE_WITH_XPU - for (int dev_id = 0; dev_id < platform::GetXPUDeviceCount(); ++dev_id) { - InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id)); - } -#endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (FLAGS_use_stream_safe_cuda_allocator) { + LOG(WARNING) << "FLAGS_use_stream_safe_cuda_allocator is invalid for " + "naive_best_fit strategy"; + FLAGS_use_stream_safe_cuda_allocator = false; + } for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { InitNaiveBestFitCUDAAllocator(platform::CUDAPlace(dev_id)); } InitNaiveBestFitCUDAPinnedAllocator(); #endif +#ifdef PADDLE_WITH_XPU + for (int dev_id = 0; dev_id < platform::GetXPUDeviceCount(); ++dev_id) { + InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id)); + } +#endif #ifdef PADDLE_WITH_ASCEND_CL for (int dev_id = 0; dev_id < platform::GetNPUDeviceCount(); ++dev_id) { InitNaiveBestFitNPUAllocator(platform::NPUPlace(dev_id)); @@ -142,18 +168,29 @@ class AllocatorFacadePrivate { case AllocatorStrategy::kAutoGrowth: { InitNaiveBestFitCPUAllocator(); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + allow_free_idle_chunk_ = allow_free_idle_chunk; + if (FLAGS_use_stream_safe_cuda_allocator) { + // TODO(Ruibiao): Support multi-stream allocator for other strategies + default_stream_ = nullptr; + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); + ++dev_id) { + InitStreamSafeCUDAAllocator(platform::CUDAPlace(dev_id), + default_stream_); + } + } else { + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); + ++dev_id) { + InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id), + allow_free_idle_chunk_); + } + } + InitNaiveBestFitCUDAPinnedAllocator(); +#endif #ifdef PADDLE_WITH_XPU for (int dev_id = 0; dev_id < platform::GetXPUDeviceCount(); ++dev_id) { InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id)); } -#endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); - ++dev_id) { - InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id), - allow_free_idle_chunk); - } - InitNaiveBestFitCUDAPinnedAllocator(); #endif break; } @@ -166,6 +203,12 @@ class AllocatorFacadePrivate { } #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (FLAGS_use_stream_safe_cuda_allocator) { + LOG(WARNING) << "FLAGS_use_stream_safe_cuda_allocator is invalid for " + "thread_local strategy"; + FLAGS_use_stream_safe_cuda_allocator = false; + } + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id)); @@ -190,24 +233,6 @@ class AllocatorFacadePrivate { CheckAllocThreadSafe(); } - inline const AllocatorMap& GetAllocatorMap() { -#ifdef PADDLE_WITH_CUDA - if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { - auto id = platform::CUDAGraph::CapturingID(); - auto iter = cuda_graph_allocator_map_.find(id); - PADDLE_ENFORCE_NE( - iter, cuda_graph_allocator_map_.end(), - 
platform::errors::PermissionDenied( - "No memory pool is prepared for CUDA Graph capturing.")); - return iter->second->allocators_; - } else { - return allocators_; - } -#else - return allocators_; -#endif - } - inline const std::shared_ptr& GetAllocator( const platform::Place& place, size_t size) { VLOG(6) << "GetAllocator" @@ -223,25 +248,106 @@ class AllocatorFacadePrivate { return iter->second; } - private: - void InitSystemAllocators() { - if (!system_allocators_.empty()) return; - system_allocators_[platform::CPUPlace()] = std::make_shared(); -#ifdef PADDLE_WITH_XPU - int device_count = platform::GetXPUDeviceCount(); - for (int i = 0; i < device_count; ++i) { - platform::XPUPlace p(i); - system_allocators_[p] = std::make_shared(p); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + const std::shared_ptr& GetAllocator( + const platform::CUDAPlace& place, const gpuStream_t& stream, + bool create_if_not_found = false) { + auto place_it = cuda_allocators_.find(place); + PADDLE_ENFORCE_NE(place_it, cuda_allocators_.end(), + platform::errors::NotFound( + "No allocator found for the place %s", place)); + + const std::map>& allocator_map = + place_it->second; + auto stream_it = allocator_map.find(stream); + if (stream_it == allocator_map.end()) { + if (create_if_not_found) { + InitStreamSafeCUDAAllocator(place, stream); + return cuda_allocators_[place][stream]; + } else { + PADDLE_THROW(platform::errors::NotFound( + "No allocator found for stream %s in place %s", stream, place)); + } + } + return stream_it->second; + } + + gpuStream_t GetDefaultStream() { return default_stream_; } + + void RecordStream(Allocation* allocation, const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(allocation->place()), true, + platform::errors::InvalidArgument( + "Not allow to record stream for an allocation with place %s", + allocation->place())); + dynamic_cast(allocation)->RecordStream(stream); + } + +#ifdef PADDLE_WITH_CUDA + void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { + PADDLE_ENFORCE_EQ(strategy_, AllocatorStrategy::kAutoGrowth, + platform::errors::InvalidArgument( + "CUDA Graph is only supported when the " + "FLAGS_allocator_strategy=\"auto_growth\", but got " + "FLAGS_allocator_strategy=\"%s\"", + FLAGS_allocator_strategy)); + auto& allocator = cuda_graph_allocator_map_[id]; + PADDLE_ENFORCE_EQ( + allocator.get(), nullptr, + platform::errors::InvalidArgument( + "The memory pool of the CUDA Graph with ID %d have been prepared.", + id)); + allocator.reset( + new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false)); + for (auto& item : allocator->allocators_) { + auto& old_allocator = item.second; + old_allocator = CUDAGraphAllocator::Create(old_allocator); } + VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id; + } + + void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { + auto iter = cuda_graph_allocator_map_.find(id); + PADDLE_ENFORCE_NE(iter, cuda_graph_allocator_map_.end(), + platform::errors::InvalidArgument( + "Cannot find CUDA Graph with ID = %d", id)); + cuda_graph_allocator_map_.erase(iter); + VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id; + } #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - system_allocators_[platform::CUDAPinnedPlace()] = - std::make_shared(); - int device_count = platform::GetCUDADeviceCount(); - for (int i = 0; i < device_count; ++i) { - platform::CUDAPlace p(i); - system_allocators_[p] = std::make_shared(p); +#endif + + private: + class ZeroSizeAllocator : public Allocator { + public: 
+ explicit ZeroSizeAllocator(platform::Place place) : place_(place) {} + bool IsAllocThreadSafe() const override { return true; } + + protected: + Allocation* AllocateImpl(size_t size) override { + return new Allocation(nullptr, 0, place_); + } + void FreeImpl(Allocation* allocation) override { delete allocation; } + + private: + platform::Place place_; + }; + + const AllocatorMap& GetAllocatorMap() { +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { + auto id = platform::CUDAGraph::CapturingID(); + auto iter = cuda_graph_allocator_map_.find(id); + PADDLE_ENFORCE_NE( + iter, cuda_graph_allocator_map_.end(), + platform::errors::PermissionDenied( + "No memory pool is prepared for CUDA Graph capturing.")); + return iter->second->allocators_; + } else { + return allocators_; } +#else + return allocators_; #endif } @@ -256,14 +362,108 @@ class AllocatorFacadePrivate { std::make_shared(platform::CUDAPinnedPlace()); } + void InitStreamSafeCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) { + PADDLE_ENFORCE_EQ( + strategy_, AllocatorStrategy::kAutoGrowth, + platform::errors::Unimplemented( + "Only support auto-growth strategey for StreamSafeCUDAAllocator, " + "the allocator strategy %d is unsupported for multi-stream", + static_cast(strategy_))); + VLOG(9) << "Init CUDA allocator for stream " << stream << " in place " << p; + std::lock_guard lock_guard(cuda_allocators_lock_); + try { + GetAllocator(p, stream); + VLOG(9) << "Other thread had build a allocator for stream " << stream + << " in place " << p; + } catch (platform::EnforceNotMet&) { + InitAutoGrowthCUDAAllocator(p, stream); + WrapStreamSafeCUDAAllocator(p, stream); + WrapCUDARetryAllocator(p, stream, FLAGS_gpu_allocator_retry_time); + } catch (...) { + throw; + } + } + void InitNaiveBestFitCUDAAllocator(platform::CUDAPlace p) { allocators_[p] = std::make_shared(p); } - void InitThreadLocalCUDAAllocator(platform::CUDAPlace p) { - allocators_[p] = std::make_shared(p); + void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) { +#if defined(PADDLE_WITH_HIP) + auto cuda_allocator = std::make_shared(p); + cuda_allocators_[p][stream] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk_); +#endif + +#if defined(PADDLE_WITH_CUDA) +#if CUDA_VERSION >= 10020 + CUdevice device; + int val; + try { + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId())); + + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuDeviceGetAttribute( + &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, + device)); + } catch (...) { + val = 0; + } + + if (val > 0 && FLAGS_use_virtual_memory_auto_growth) { + auto cuda_allocator = std::make_shared(p); + cuda_allocators_[p][stream] = + std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), p); + } else { + auto cuda_allocator = std::make_shared(p); + cuda_allocators_[p][stream] = + std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), + allow_free_idle_chunk_); + } +#else + auto cuda_allocator = std::make_shared(p); + auto alignment = platform::GpuMinChunkSize(); + bool need_addr_align = true; + // NOTE: sometimes, since cuda runtime can not be forked, calling any cuda + // API in that case may got cuda error(3), i.e., + // cudaErrorInitializationError. And, the CUDAAllocator is only initialized + // but not really used. 
+ // Here, the try-catch block is added to handle the case that + // GetDeviceProperties() may failed in the multiple process(for example, in + // dataloader with num_worker > 0) + try { + const auto& prop = platform::GetDeviceProperties(p.GetDeviceId()); + need_addr_align = prop.textureAlignment < alignment; + VLOG(4) << "GetDeviceProperties ok, textureAlignment: " + << prop.textureAlignment + << ", set need_addr_align=" << need_addr_align; + } catch (...) { + need_addr_align = true; + VLOG(4) << "GetDeviceProperties failed, set need_addr_align=true"; + } + // The address returned is aligned already, + // ref: + // https://stackoverflow.com/questions/14082964/cuda-alignment-256bytes-seriously/14083295#14083295 + std::shared_ptr underlying_allocator{nullptr}; + if (need_addr_align) { + VLOG(10) << "use AlignedAllocator with alignment: " << alignment; + underlying_allocator = + std::make_shared(underlying_allocator, alignment); + } else { + VLOG(10) << "not use AlignedAllocator with alignment: " << alignment; + underlying_allocator = cuda_allocator; + } + + cuda_allocators_[p][stream] = std::make_shared( + underlying_allocator, alignment, 0, allow_free_idle_chunk_); +#endif +#endif } + // NOTE(Ruibiao): Old single-stream version, will be removed later void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, bool allow_free_idle_chunk) { #if defined(PADDLE_WITH_HIP) @@ -337,6 +537,37 @@ class AllocatorFacadePrivate { #endif #endif } + + void InitThreadLocalCUDAAllocator(platform::CUDAPlace p) { + allocators_[p] = std::make_shared(p); + } + + void WrapStreamSafeCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) { + const std::shared_ptr& underlying_allocator = + GetAllocator(p, stream); + cuda_allocators_[p][stream] = std::make_shared( + underlying_allocator, p, stream); + } + + void WrapCUDARetryAllocator(platform::CUDAPlace p, gpuStream_t stream, + size_t retry_time) { + PADDLE_ENFORCE_GT( + retry_time, 0, + platform::errors::InvalidArgument( + "Retry time should be larger than 0, but got %d", retry_time)); + std::shared_ptr allocator = GetAllocator(p, stream); + allocator = std::make_shared(allocator, retry_time); + } + + static void CheckCUDAAllocThreadSafe(const CUDAAllocatorMap& allocators) { + for (auto& place_pair : allocators) { + for (auto& stream_pair : place_pair.second) { + PADDLE_ENFORCE_EQ(stream_pair.second->IsAllocThreadSafe(), true, + platform::errors::InvalidArgument( + "Public allocators must be thread safe")); + } + } + } #endif #ifdef PADDLE_WITH_XPU @@ -354,25 +585,28 @@ class AllocatorFacadePrivate { allocators_[platform::NPUPinnedPlace()] = std::make_shared(); } - #endif - class ZeroSizeAllocator : public Allocator { - public: - explicit ZeroSizeAllocator(platform::Place place) : place_(place) {} - - bool IsAllocThreadSafe() const override { return true; } - - protected: - Allocation* AllocateImpl(size_t size) override { - return new Allocation(nullptr, 0, place_); + void InitSystemAllocators() { + if (!system_allocators_.empty()) return; + system_allocators_[platform::CPUPlace()] = std::make_shared(); +#ifdef PADDLE_WITH_XPU + int device_count = platform::GetXPUDeviceCount(); + for (int i = 0; i < device_count; ++i) { + platform::XPUPlace p(i); + system_allocators_[p] = std::make_shared(p); } - - void FreeImpl(Allocation* allocation) override { delete allocation; } - - private: - platform::Place place_; - }; +#endif +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + system_allocators_[platform::CUDAPinnedPlace()] = + std::make_shared(); + int 
device_count = platform::GetCUDADeviceCount(); + for (int i = 0; i < device_count; ++i) { + platform::CUDAPlace p(i); + system_allocators_[p] = std::make_shared(p); + } +#endif + } void InitZeroSizeAllocators() { if (!zero_size_allocators_.empty()) return; @@ -415,8 +649,14 @@ class AllocatorFacadePrivate { CheckAllocThreadSafe(allocators_); CheckAllocThreadSafe(zero_size_allocators_); CheckAllocThreadSafe(system_allocators_); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (FLAGS_use_stream_safe_cuda_allocator) { + CheckCUDAAllocThreadSafe(cuda_allocators_); + } +#endif } + // NOTE(Ruibiao): Old single-stream version, will be removed later void WrapCUDARetryAllocator(size_t retry_time) { PADDLE_ENFORCE_GT( retry_time, 0, @@ -429,53 +669,22 @@ class AllocatorFacadePrivate { } } -#ifdef PADDLE_WITH_CUDA - - public: - void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { - PADDLE_ENFORCE_EQ(strategy_, AllocatorStrategy::kAutoGrowth, - platform::errors::InvalidArgument( - "CUDA Graph is only supported when the " - "FLAGS_allocator_strategy=\"auto_growth\", but got " - "FLAGS_allocator_strategy=\"%s\"", - FLAGS_allocator_strategy)); - auto& allocator = cuda_graph_allocator_map_[id]; - PADDLE_ENFORCE_EQ( - allocator.get(), nullptr, - platform::errors::InvalidArgument( - "The memory pool of the CUDA Graph with ID %d have been prepared.", - id)); - allocator.reset( - new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false)); - for (auto& item : allocator->allocators_) { - auto& old_allocator = item.second; - old_allocator = CUDAGraphAllocator::Create(old_allocator); - } - VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id; - } - - void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { - auto iter = cuda_graph_allocator_map_.find(id); - PADDLE_ENFORCE_NE(iter, cuda_graph_allocator_map_.end(), - platform::errors::InvalidArgument( - "Cannot find CUDA Graph with ID = %d", id)); - cuda_graph_allocator_map_.erase(iter); - VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id; - } -#endif - - private: - AllocatorMap allocators_; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + // a standalone CUDA allocator to support multi-stream GC in new executor + CUDAAllocatorMap cuda_allocators_; + gpuStream_t default_stream_; + SpinLock cuda_allocators_lock_; #ifdef PADDLE_WITH_CUDA std::unordered_map> cuda_graph_allocator_map_; +#endif #endif AllocatorStrategy strategy_; - + AllocatorMap allocators_; static AllocatorMap zero_size_allocators_; static AllocatorMap system_allocators_; + bool allow_free_idle_chunk_; }; - AllocatorFacadePrivate::AllocatorMap AllocatorFacadePrivate::zero_size_allocators_; AllocatorFacadePrivate::AllocatorMap AllocatorFacadePrivate::system_allocators_; @@ -491,6 +700,18 @@ AllocatorFacade& AllocatorFacade::Instance() { return instance; } +const std::shared_ptr& AllocatorFacade::GetAllocator( + const platform::Place& place) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) && + FLAGS_use_system_allocator == false) { + return m_->GetAllocator(BOOST_GET_CONST(platform::CUDAPlace, place), + m_->GetDefaultStream()); + } +#endif + return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); +} + std::shared_ptr AllocatorFacade::AllocShared( const platform::Place& place, size_t size) { return std::shared_ptr(Alloc(place, size)); @@ -498,17 +719,80 @@ std::shared_ptr AllocatorFacade::AllocShared( AllocationPtr AllocatorFacade::Alloc(const 
platform::Place& place, size_t size) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) && + size > 0 && FLAGS_use_system_allocator == false) { + return Alloc(BOOST_GET_CONST(platform::CUDAPlace, place), size, + m_->GetDefaultStream()); + } +#endif return m_->GetAllocator(place, size)->Allocate(size); } uint64_t AllocatorFacade::Release(const platform::Place& place) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) && + FLAGS_use_system_allocator == false) { + return Release(BOOST_GET_CONST(platform::CUDAPlace, place), + m_->GetDefaultStream()); + } +#endif return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1) ->Release(place); } -const std::shared_ptr& AllocatorFacade::GetAllocator( - const platform::Place& place) { - return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +std::shared_ptr AllocatorFacade::AllocShared( + const platform::CUDAPlace& place, size_t size, const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ( + FLAGS_use_stream_safe_cuda_allocator, true, + platform::errors::Unimplemented( + "StreamSafeCUDAAllocator is disabled, you should not call this " + "multi-stream 'AllocaShared' function. " + "To enable it, you can enter 'export " + "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + return std::shared_ptr(Alloc(place, size, stream)); +} + +AllocationPtr AllocatorFacade::Alloc(const platform::CUDAPlace& place, + size_t size, const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ( + FLAGS_use_stream_safe_cuda_allocator, true, + platform::errors::Unimplemented( + "StreamSafeCUDAAllocator is disabled, you should not call this " + "multi-stream 'Alloca' function. " + "To enable it, you can enter 'export " + "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + if (LIKELY(size > 0 && FLAGS_use_system_allocator == false)) { + return m_->GetAllocator(place, stream, /* creat_if_not_found = */ true) + ->Allocate(size); + } else { + return m_->GetAllocator(place, size)->Allocate(size); + } +} + +uint64_t AllocatorFacade::Release(const platform::CUDAPlace& place, + const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ( + FLAGS_use_stream_safe_cuda_allocator, true, + platform::errors::Unimplemented( + "StreamSafeCUDAAllocator is disabled, you should not call this " + "multi-stream 'Release' function. " + "To enable it, you can enter 'export " + "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + return m_->GetAllocator(place, stream)->Release(place); +} + +void AllocatorFacade::RecordStream(Allocation* allocation, + const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ( + FLAGS_use_stream_safe_cuda_allocator, true, + platform::errors::Unimplemented( + "StreamSafeCUDAAllocator is disabled, you should not call this " + "'RecordStream' function. 
" + "To enable it, you can enter 'export " + "FLAGS_use_stream_safe_cuda_allocator=true' in the terminal.")); + m_->RecordStream(allocation, stream); } #ifdef PADDLE_WITH_CUDA @@ -520,7 +804,7 @@ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { return m_->RemoveMemoryPoolOfCUDAGraph(id); } #endif - +#endif } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 8d889ec38e..4cd8b4e91e 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -26,6 +26,7 @@ namespace paddle { namespace memory { namespace allocation { + #ifdef PADDLE_WITH_ASCEND_CL using NPUPinnedAllocator = paddle::memory::allocation::NPUPinnedAllocator; #endif @@ -40,26 +41,34 @@ using NPUPinnedAllocator = paddle::memory::allocation::NPUPinnedAllocator; class AllocatorFacadePrivate; class AllocatorFacade { public: - ~AllocatorFacade(); AllocatorFacade(const AllocatorFacade& o) = delete; const AllocatorFacade& operator=(const AllocatorFacade& o) = delete; + ~AllocatorFacade(); static AllocatorFacade& Instance(); + const std::shared_ptr& GetAllocator(const platform::Place& place); + // Allocate a shared allocation. std::shared_ptr AllocShared(const platform::Place& place, size_t size); - // Allocate a unique allocation. AllocationPtr Alloc(const platform::Place& place, size_t size); - // Release unused memory pool. uint64_t Release(const platform::Place& place); - const std::shared_ptr& GetAllocator(const platform::Place& place); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + std::shared_ptr AllocShared(const platform::CUDAPlace& place, + size_t size, + const gpuStream_t& stream); + AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size, + const gpuStream_t& stream); + uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream); + void RecordStream(Allocation* allocation, const gpuStream_t& stream); #ifdef PADDLE_WITH_CUDA void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id); void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id); +#endif #endif // TODO(yy): Allocate a Copy-On-Write allocation? diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc new file mode 100644 index 0000000000..b2e13af6ef --- /dev/null +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace memory { +namespace allocation { + +StreamSafeCUDAAllocation::StreamSafeCUDAAllocation( + AllocationPtr underlying_allocation, gpuStream_t owning_stream) + : Allocation(underlying_allocation->ptr(), underlying_allocation->size(), + underlying_allocation->place()), + underlying_allocation_(std::move(underlying_allocation)), + owning_stream_(owning_stream), + recorded_streams_(std::make_shared>()) {} + +void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) { + VLOG(8) << "Record stream " << stream << " to " << ptr(); + if (stream == owning_stream_) { + return; + } + std::lock_guard lock_guard(spin_lock_); + recorded_streams_->insert(stream); +} + +std::shared_ptr> +StreamSafeCUDAAllocation::GetRecordedStreams() { + return recorded_streams_; +} + +StreamSafeCUDAAllocator::StreamSafeCUDAAllocator( + const std::shared_ptr& underlying_allocator, + const platform::CUDAPlace& place, const gpuStream_t default_stream) + : underlying_allocator_(underlying_allocator), + place_(place), + default_stream_(default_stream) { + std::lock_guard lock_guard(allocators_map_lock_); + allocators_map_[place].emplace_back(this); +} + +StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() { + std::lock_guard lock_guard(allocators_map_lock_); + std::vector& allocators = allocators_map_[place_]; + allocators.erase(std::remove(allocators.begin(), allocators.end(), this), + allocators.end()); +} + +bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; } + +Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) { + ProcessEventsAndFree(); + AllocationPtr underlying_allocation; + try { + underlying_allocation = underlying_allocator_->Allocate(size); + } catch (BadAlloc&) { + VLOG(9) << "Allocation failed when allocating " << size << " bytes"; + uint64_t release_size = ReleaseImpl(place_); + VLOG(9) << "Release " << release_size << " bytes memory from all streams"; + try { + underlying_allocation = underlying_allocator_->Allocate(size); + } catch (...) { + VLOG(9) << "Still allocation failed after release memory"; + throw; + } + } catch (...) 
{ + throw; + } + + StreamSafeCUDAAllocation* allocation = new StreamSafeCUDAAllocation( + std::move(underlying_allocation), default_stream_); + return allocation; +} + +void StreamSafeCUDAAllocator::FreeImpl(Allocation* allocation) { + if (dynamic_cast(allocation) + ->GetRecordedStreams() + ->empty()) { + delete allocation; + } else { + std::lock_guard lock_guard(outstanding_events_map_lock_); + FreeStreamSafeCUDAAllocation(allocation); + } +} + +uint64_t StreamSafeCUDAAllocator::ReleaseImpl(const platform::Place& place) { + std::lock_guard lock_guard(allocators_map_lock_); + std::vector& allocators = + allocators_map_[BOOST_GET_CONST(platform::CUDAPlace, place)]; + uint64_t release_size = 0; + for (StreamSafeCUDAAllocator* allocator : allocators) { + release_size += allocator->ProcessEventsAndFreeWithRelease(); + } + return release_size; +} + +void StreamSafeCUDAAllocator::CreateEventForAllRecordedStream( + std::set* recorded_streams, + std::deque* outstanding_events) { + for (gpuStream_t stream : *recorded_streams) { + gpuEvent_t event; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS( + hipEventCreateWithFlags(&event, hipEventDisableTiming)); + PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, stream)); +#endif + outstanding_events->emplace_back(event); + VLOG(9) << "Record event " << event << " in stream " << stream; + } + recorded_streams->clear(); +} + +void StreamSafeCUDAAllocator::FreeStreamSafeCUDAAllocation( + Allocation* allocation) { + std::deque& outstanding_events = + outstanding_events_map_[allocation]; + CreateEventForAllRecordedStream( + dynamic_cast(allocation) + ->GetRecordedStreams() + .get(), + &outstanding_events); + if (!outstanding_events.empty()) { + VLOG(8) << allocation->ptr() << " is not ready to free"; + return; + } + + VLOG(8) << "Free " << allocation->ptr(); + outstanding_events_map_.erase(allocation); + delete allocation; +} + +void StreamSafeCUDAAllocator::ProcessEventsAndFree() { + std::lock_guard lock_guard(outstanding_events_map_lock_); + for (auto map_it = outstanding_events_map_.begin(); + map_it != outstanding_events_map_.end();) { + std::deque& outstanding_events = map_it->second; + VLOG(10) << "Check " << outstanding_events.size() + << " outstanding events for " << map_it->first->ptr(); + auto deque_it = outstanding_events.begin(); + while (deque_it != outstanding_events.end()) { +#ifdef PADDLE_WITH_CUDA + gpuError_t err = cudaEventQuery(*deque_it); + if (err == cudaErrorNotReady) { + VLOG(10) << "Event " << *deque_it << " for " << map_it->first->ptr() + << " is not completed"; + outstanding_events.erase(outstanding_events.begin(), deque_it); + break; + } + PADDLE_ENFORCE_CUDA_SUCCESS(err); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(*deque_it)); +#else + gpuError_t err = hipEventQuery(*deque_it); + if (err == hipErrorNotReady) { + VLOG(10) << "Event " << *deque_it << " for " << map_it->first->ptr() + << " is not completed"; + // Erase the completded event before "deque_it" + outstanding_events.erase(outstanding_events.begin(), deque_it); + break; + } + PADDLE_ENFORCE_CUDA_SUCCESS(err); + PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(*deque_it)); +#endif + ++deque_it; + } + + if (deque_it == outstanding_events.end()) { + outstanding_events.clear(); + Allocation* allocation = map_it->first; + // "map_it" may be invalid after calling FreeStreamSafeCUDAAllocation + auto next_it = 
++map_it; + FreeStreamSafeCUDAAllocation(allocation); + map_it = next_it; + } else { + ++map_it; + } + } +} + +uint64_t StreamSafeCUDAAllocator::ProcessEventsAndFreeWithRelease() { + ProcessEventsAndFree(); + return underlying_allocator_->Release(place_); +} + +std::map> + StreamSafeCUDAAllocator::allocators_map_; +SpinLock StreamSafeCUDAAllocator::allocators_map_lock_; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h new file mode 100644 index 0000000000..a516558228 --- /dev/null +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h @@ -0,0 +1,83 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifdef PADDLE_WITH_CUDA +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/allocation/spin_lock.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class StreamSafeCUDAAllocation : public Allocation { + public: + StreamSafeCUDAAllocation(AllocationPtr underlying_allocation, + gpuStream_t owning_stream); + void RecordStream(gpuStream_t stream); + std::shared_ptr> GetRecordedStreams(); + + private: + AllocationPtr underlying_allocation_; + gpuStream_t owning_stream_; + std::shared_ptr> recorded_streams_; + SpinLock spin_lock_; +}; + +class StreamSafeCUDAAllocator : public Allocator { + public: + StreamSafeCUDAAllocator( + const std::shared_ptr &underlying_allocator, + const platform::CUDAPlace &place, const gpuStream_t default_stream); + ~StreamSafeCUDAAllocator(); + bool IsAllocThreadSafe() const override; + + protected: + Allocation *AllocateImpl(size_t size) override; + void FreeImpl(Allocation *allocation) override; + uint64_t ReleaseImpl(const platform::Place &place) override; + + private: + void CreateEventForAllRecordedStream( + std::set *recorded_streams, + std::deque *outstanding_events); + void FreeStreamSafeCUDAAllocation(Allocation *allocation); + void ProcessEventsAndFree(); + uint64_t ProcessEventsAndFreeWithRelease(); + + static std::map> + allocators_map_; + static SpinLock allocators_map_lock_; + + std::shared_ptr underlying_allocator_; + platform::CUDAPlace place_; + gpuStream_t default_stream_; + std::map> outstanding_events_map_; + SpinLock outstanding_events_map_lock_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc index 078e841068..4921b87ccd 100644 --- a/paddle/fluid/memory/malloc.cc +++ b/paddle/fluid/memory/malloc.cc @@ -20,18 +20,40 @@ limitations under the License. 
*/ namespace paddle { namespace memory { -std::shared_ptr AllocShared(const platform::Place &place, +std::shared_ptr AllocShared(const platform::Place& place, size_t size) { return allocation::AllocatorFacade::Instance().AllocShared(place, size); } -AllocationPtr Alloc(const platform::Place &place, size_t size) { +AllocationPtr Alloc(const platform::Place& place, size_t size) { return allocation::AllocatorFacade::Instance().Alloc(place, size); } -uint64_t Release(const platform::Place &place) { +uint64_t Release(const platform::Place& place) { return allocation::AllocatorFacade::Instance().Release(place); } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +std::shared_ptr AllocShared(const platform::CUDAPlace& place, + size_t size, + const gpuStream_t& stream) { + return allocation::AllocatorFacade::Instance().AllocShared(place, size, + stream); +} + +AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size, + const gpuStream_t& stream) { + return allocation::AllocatorFacade::Instance().Alloc(place, size, stream); +} + +uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream) { + return allocation::AllocatorFacade::Instance().Release(place, stream); +} + +void RecordStream(Allocation* allocation, const gpuStream_t& stream) { + return allocation::AllocatorFacade::Instance().RecordStream(allocation, + stream); +} +#endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h index 3b8d07548e..2aa9fbe6ad 100644 --- a/paddle/fluid/memory/malloc.h +++ b/paddle/fluid/memory/malloc.h @@ -40,5 +40,18 @@ extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size); extern uint64_t Release(const platform::Place& place); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +extern std::shared_ptr AllocShared(const platform::CUDAPlace& place, + size_t size, + const gpuStream_t& stream); + +extern AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size, + const gpuStream_t& stream); + +extern uint64_t Release(const platform::CUDAPlace& place, + const gpuStream_t& stream); + +void RecordStream(Allocation* allocation, const gpuStream_t& stream); +#endif } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu new file mode 100644 index 0000000000..6a5818fd96 --- /dev/null +++ b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu @@ -0,0 +1,229 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
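The test below drives this interface from many streams and threads at once. For a quick orientation, this is the intended call pattern in isolation, assuming a CUDA (non-HIP) build, FLAGS_use_stream_safe_cuda_allocator=true with the auto_growth strategy, and code living inside namespace paddle as the test does; the scale kernel is illustrative and error checking is omitted:

    __global__ void scale(float* x, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] *= 2.0f;
    }

    void MultiStreamUsageSketch() {
      platform::CUDAPlace place(0);
      cudaStream_t s1 = nullptr, s2 = nullptr;
      cudaStreamCreate(&s1);
      cudaStreamCreate(&s2);

      // The allocation is bound to s1 when it is created ...
      std::shared_ptr<memory::Allocation> buf =
          memory::AllocShared(place, 1024 * sizeof(float), s1);

      // ... but it is also consumed on s2, so the allocator must be told.
      scale<<<4, 256, 0, s2>>>(reinterpret_cast<float*>(buf->ptr()), 1024);
      memory::RecordStream(buf.get(), s2);

      // Dropping the last reference is safe even while s2 is still running:
      // the block is only reused after the event recorded on s2 completes.
      buf.reset();

      memory::Release(place, s1);  // hand idle cached chunks back to the driver
      // (stream destruction omitted)
    }
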
+ +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif + +#ifdef PADDLE_WITH_HIP +#include +#endif + +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace memory { + +__global__ void add_kernel(int *x, int n) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + for (int i = tid; i < n; i += blockDim.x * gridDim.x) { + atomicAdd(x + i, tid); + } +} + +class StreamSafeCUDAAllocTest : public ::testing::Test { + protected: + void SetUp() override { + place_ = platform::CUDAPlace(); + stream_num_ = 64; + grid_num_ = 1; + block_num_ = 64; + data_num_ = 64; + default_stream = nullptr; + + streams_.reserve(stream_num_); + streams_.emplace_back(default_stream); + for (size_t i = 1; i < stream_num_; ++i) { + gpuStream_t stream; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream)); +#endif + streams_.emplace_back(stream); + } + + for (size_t i = 0; i < stream_num_; ++i) { + size_t allocation_size = data_num_ * sizeof(int); + std::shared_ptr allocation = + AllocShared(place_, allocation_size, streams_[i]); +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemset(allocation->ptr(), 0, allocation->size())); +#else + PADDLE_ENFORCE_CUDA_SUCCESS( + hipMemset(allocation->ptr(), 0, allocation->size())); +#endif + allocations_.emplace_back(allocation); + } + } + + void SingleStreamRun(size_t idx) { + for (size_t i = 0; i < stream_num_; ++i) { + int *x = reinterpret_cast(allocations_[i]->ptr()); + add_kernel<<>>(x, data_num_); + if (i != idx) { + RecordStream(allocations_[i].get(), streams_[idx]); + } + } + } + + void MultiStreamRun() { + for (int i = 0; i < stream_num_; ++i) { + SingleStreamRun(i); + } + allocations_.clear(); // fast_gc + } + + void MultiThreadMUltiStreamRun() { + std::vector threads; + for (size_t i = 0; i < stream_num_; ++i) { + threads.push_back( + std::thread(&StreamSafeCUDAAllocTest::SingleStreamRun, this, i)); + } + for (size_t i = 0; i < stream_num_; ++i) { + threads[i].join(); + } + allocations_.clear(); // fast_gc + } + + void CheckResult() { + auto host_x = std::unique_ptr(new int[data_num_]); + size_t thread_num = grid_num_ * block_num_; + for (int i = 0; i < stream_num_; ++i) { +// tricky code, the allocations are still accessible even though +// allocations_.clear() has been called +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpy(host_x.get(), allocations_[i]->ptr(), + data_num_ * sizeof(int), cudaMemcpyDeviceToHost)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS( + hipMemcpy(host_x.get(), allocations_[i]->ptr(), + data_num_ * sizeof(int), hipMemcpyDeviceToHost)); +#endif + for (int j = 0; j < data_num_; ++j) { + EXPECT_TRUE(host_x[j] == (j % thread_num) * stream_num_); + } + } + } + + void TearDown() override { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize()); +#endif + for (gpuStream_t stream : streams_) { + Release(place_, stream); + } + + for (size_t i = 1; i < stream_num_; ++i) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(streams_[i])); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamDestroy(streams_[i])); +#endif + } + + uint64_t cuda_malloc_size = + platform::RecordedCudaMallocSize(place_.GetDeviceId()); + ASSERT_EQ(cuda_malloc_size, 0) << "Found " << cuda_malloc_size + << " bytes memory that not released yet," + << 
" there may be a memory leak problem"; + } + + size_t stream_num_; + size_t grid_num_; + size_t block_num_; + size_t data_num_; + platform::CUDAPlace place_; + gpuStream_t default_stream; + std::vector streams_; + std::vector> allocations_; +}; + +TEST_F(StreamSafeCUDAAllocTest, CUDAMutilStreamTest) { + MultiStreamRun(); + CheckResult(); +} + +TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) { + MultiThreadMUltiStreamRun(); + CheckResult(); +} + +TEST(StreamSafeCUDAAllocInterfaceTest, AllocInterfaceTest) { + platform::CUDAPlace place = platform::CUDAPlace(); + size_t alloc_size = 256; + + std::shared_ptr allocation_implicit_stream = + AllocShared(place, alloc_size); + EXPECT_GE(allocation_implicit_stream->size(), alloc_size); + + void *address = allocation_implicit_stream->ptr(); + allocation_implicit_stream.reset(); + + gpuStream_t default_stream = nullptr; + allocation::AllocationPtr allocation_unique = + Alloc(place, alloc_size, default_stream); + EXPECT_GE(allocation_unique->size(), alloc_size); + EXPECT_EQ(allocation_unique->ptr(), address); +} + +TEST(StreamSafeCUDAAllocRetryTest, RetryTest) { + platform::CUDAPlace place = platform::CUDAPlace(); + gpuStream_t stream1, stream2; +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream1)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream2)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream1)); + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream2)); +#endif + size_t available_size = platform::GpuAvailableMemToAlloc(); + // alloc_size < available_size < 2 * alloc_size + size_t alloc_size = available_size / 4 * 3; + + std::shared_ptr allocation1 = + AllocShared(place, alloc_size, stream1); + std::shared_ptr allocation2; + + std::thread th([&allocation2, &place, &stream2, alloc_size]() { + std::this_thread::sleep_for(std::chrono::seconds(1)); + allocation2 = AllocShared(place, alloc_size, stream2); + }); + allocation1.reset(); // free but not release + th.join(); + EXPECT_GE(allocation2->size(), alloc_size); + allocation2.reset(); + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize()); +#endif + + Release(place, stream1); + Release(place, stream2); +} + +} // namespace memory +} // namespace paddle -- GitLab