From a1ec1d5a4933a8242c2b1b14deafd449de506e26 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Mon, 8 Nov 2021 19:09:17 +0800 Subject: [PATCH] Use cuda virtual memory management and merge blocks (#36189) * Use cuda virtual memory management and merge blocks, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * window dll, test=develop * fix cuda error of CUDA_ERROR_NOT_INITIALIZED, test=develop * use autogrowthv2 for system allocator, test=develop * remove ~CUDAVirtualMemAllocator(), test=develop * refine, test=develop * fix cuda error of CUDA_ERROR_NOT_INITIALIZED, test=develop * fix cuda error of CUDA_ERROR_NOT_INITIALIZED, test=develop * fix bug, test=develop * revert system allocator, test =develop * revert multiprocessing, test=develop * fix AutoGrowthBestFitAllocatorV2 mutxt, test=develop * catch cudaErrorInitializationError when create allocator, test=develop * fix cuMemSetAccess use, test=develop * refine cuda api use, test=develop * refine, test=develop * for test, test=develop * for test, test=develop * switch to v2, test=develop * refine virtual allocator, test=develop * Record cuMemCreate and cuMemRelease, test=develop * refine, test=develop * avoid out of bounds, test=develop * rename allocator, test=develop * refine, test=develop * use PADDLE_ENFORCE_CUDA_SUCCESS, test=develop * for test,test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop --- paddle/fluid/memory/allocation/CMakeLists.txt | 10 +- .../memory/allocation/allocator_facade.cc | 44 +++ .../allocation/cuda_virtual_mem_allocator.cc | 225 ++++++++++++++++ .../allocation/cuda_virtual_mem_allocator.h | 62 +++++ ...l_memory_auto_growth_best_fit_allocator.cc | 254 ++++++++++++++++++ ...al_memory_auto_growth_best_fit_allocator.h | 84 ++++++ paddle/fluid/platform/dynload/CMakeLists.txt | 4 +- paddle/fluid/platform/dynload/cuda_driver.cc | 3 + paddle/fluid/platform/dynload/cuda_driver.h | 18 +- .../fluid/platform/dynload/dynamic_loader.cc | 8 + paddle/fluid/platform/enforce.h | 12 + paddle/fluid/platform/external_error.proto | 1 + paddle/fluid/platform/gpu_info.cc | 45 ++++ paddle/fluid/platform/gpu_info.h | 14 + ...est_softmax_mask_fuse_upper_triangle_op.py | 10 +- 15 files changed, 785 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc create mode 100644 paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h create mode 100644 paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc create mode 100644 paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 4aa1900f53f..58979d6c3e1 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -18,6 +18,9 @@ if (WITH_GPU) nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator) + if(CUDA_VERSION GREATER_EQUAL 10.2) + nv_library(cuda_virtual_mem_allocator SRCS cuda_virtual_mem_allocator.cc DEPS dynload_cuda) + endif() endif() if (WITH_ROCM) @@ -36,6 +39,9 @@ cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) if (WITH_GPU OR WITH_ROCM) set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator) + if(CUDA_VERSION GREATER_EQUAL 10.2) + list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator) + endif() elseif(WITH_XPU) set(AllocatorFacadeDeps xpu_info) elseif(WITH_ASCEND) @@ -72,7 +78,7 @@ else() cpu_allocator) endif() -list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator best_fit_allocator) +list(APPEND AllocatorFacadeDeps cpu_allocator locked_allocator aligned_allocator retry_allocator buffered_allocator naive_best_fit_allocator auto_growth_best_fit_allocator virtual_memory_auto_growth_best_fit_allocator best_fit_allocator) if (WITH_ASCEND_CL) list(APPEND AllocatorFacadeDeps npu_pinned_allocator) @@ -107,6 +113,8 @@ cc_library(auto_growth_best_fit_allocator SRCS auto_growth_best_fit_allocator.cc cc_test(auto_growth_best_fit_allocator_facade_test SRCS auto_growth_best_fit_allocator_facade_test.cc DEPS cpu_allocator auto_growth_best_fit_allocator) cc_test(auto_growth_best_fit_allocator_test SRCS auto_growth_best_fit_allocator_test.cc DEPS auto_growth_best_fit_allocator) +cc_library(virtual_memory_auto_growth_best_fit_allocator SRCS virtual_memory_auto_growth_best_fit_allocator.cc DEPS allocator aligned_allocator) + if(NOT WIN32) cc_library(mmap_allocator SRCS mmap_allocator.cc DEPS allocator) cc_test(mmap_allocator_test SRCS mmap_allocator_test.cc DEPS mmap_allocator allocator) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 93c9887913d..9da735636fc 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -33,6 +33,11 @@ #include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/gpu_info.h" #endif +#if CUDA_VERSION >= 10020 +#include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#include "paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h" +#include "paddle/fluid/platform/dynload/cuda_driver.h" +#endif #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_graph.h" #endif @@ -51,6 +56,9 @@ PADDLE_DEFINE_EXPORTED_bool( "Whether to use system allocator to allocate CPU and GPU memory. " "Only used for unittests."); +PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, + "Use VirtualMemoryAutoGrowthBestFitAllocator."); + DECLARE_string(allocator_strategy); namespace paddle { @@ -258,6 +266,40 @@ class AllocatorFacadePrivate { void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, bool allow_free_idle_chunk) { +#if defined(PADDLE_WITH_HIP) + auto cuda_allocator = std::make_shared(p); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); +#endif + +#if defined(PADDLE_WITH_CUDA) +#if CUDA_VERSION >= 10020 + CUdevice device; + int val; + try { + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuDeviceGet(&device, p.GetDeviceId())); + + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuDeviceGetAttribute( + &val, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, + device)); + } catch (...) { + val = 0; + } + + if (val > 0 && FLAGS_use_virtual_memory_auto_growth) { + auto cuda_allocator = std::make_shared(p); + allocators_[p] = + std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), p); + } else { + auto cuda_allocator = std::make_shared(p); + allocators_[p] = std::make_shared( + cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); + } + +#else auto cuda_allocator = std::make_shared(p); auto alignment = platform::GpuMinChunkSize(); bool need_addr_align = true; @@ -292,6 +334,8 @@ class AllocatorFacadePrivate { } allocators_[p] = std::make_shared( underlying_allocator, alignment, 0, allow_free_idle_chunk); +#endif +#endif } #endif diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc new file mode 100644 index 00000000000..ef64c3bdb35 --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.cc @@ -0,0 +1,225 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif + +#include +#include "paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h" +#include "paddle/fluid/platform/enforce.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/dynload/cuda_driver.h" +#include "paddle/fluid/platform/gpu_info.h" +#endif +#if CUDA_VERSION >= 10020 + +namespace paddle { +namespace memory { +namespace allocation { + +CUDAVirtualMemAllocator::CUDAVirtualMemAllocator( + const platform::CUDAPlace& place) + : place_(place) { + CUmemAllocationProp prop = {}; + + // Setup the properties common for all the chunks + // The allocations will be device pinned memory. + // This property structure describes the physical location where the memory + // will be allocated via cuMemCreate allong with additional properties In this + // case, the allocation will be pinnded device memory local to a given device. + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = place.device; + prop_ = prop; + + // Prepare the access descriptor array indicating where and how the backings + // should be visible. + access_desc_.resize(platform::GetCUDADeviceCount()); + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { + if (place.device != dev_id) { + int capable = 0; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaDeviceCanAccessPeer(&capable, place.device, dev_id)); + if (!capable) { + continue; + } + } + // Specify which device we are adding mappings for. + access_desc_[dev_id].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc_[dev_id].location.id = dev_id; + + // Specify both read and write access. + access_desc_[dev_id].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } + + // Get the minimum granularity needed for all devices + // (the max of the minimum granularity of each participating device) + granularity_ = 0; + for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { + size_t granularity; + prop.location.id = dev_id; + PADDLE_ENFORCE_CUDA_SUCCESS( + paddle::platform::dynload::cuMemGetAllocationGranularity( + &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + granularity_ = std::max(granularity, granularity_); + } + + size_t actual_avail, actual_total; + paddle::platform::CUDADeviceGuard guard(place.device); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); + + virtual_mem_size_ = AlignedSize(actual_total, granularity_); + + // Reserve the required contiguous virtual address space for the allocations + // The maximum video memory size we can apply for is the video memory size of + // GPU, + // so the virtual address space size we reserve is equal to the GPU video + // memory size + PADDLE_ENFORCE_CUDA_SUCCESS(paddle::platform::dynload::cuMemAddressReserve( + &virtual_mem_base_, virtual_mem_size_, 0, 0, 0)); + + virtual_mem_alloced_offset_ = 0; +} + +bool CUDAVirtualMemAllocator::IsAllocThreadSafe() const { return false; } + +void CUDAVirtualMemAllocator::FreeImpl(Allocation* allocation) { + PADDLE_ENFORCE_EQ( + BOOST_GET_CONST(platform::CUDAPlace, allocation->place()), place_, + platform::errors::PermissionDenied( + "GPU memory is freed in incorrect device. This may be a bug")); + + auto iter = virtual_2_physical_map_.find( + reinterpret_cast(allocation->ptr())); + if (iter == virtual_2_physical_map_.end()) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Can not find virtual memory address at %s", allocation->ptr())); + } + + int prev_id; + cudaGetDevice(&prev_id); + if (prev_id != place_.device) { + cudaSetDevice(place_.device); + } + + auto result = + paddle::platform::dynload::cuMemUnmap(iter->first, iter->second.second); + if (result != CUDA_ERROR_DEINITIALIZED) { + PADDLE_ENFORCE_CUDA_SUCCESS(result); + } + + if (result != CUDA_ERROR_DEINITIALIZED) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::RecordedCuMemRelease( + iter->second.first, iter->second.second, place_.device)); + } + + if (prev_id != place_.device) { + cudaSetDevice(prev_id); + } + + virtual_2_physical_map_.erase(iter); + + delete allocation; +} + +Allocation* CUDAVirtualMemAllocator::AllocateImpl(size_t size) { + size = AlignedSize(size, granularity_); + + CUdeviceptr ptr = virtual_mem_base_ + virtual_mem_alloced_offset_; + + if (ptr + size > virtual_mem_base_ + virtual_mem_size_) { + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on GPU Virtual Memory %d. " + "Cannot allocate %s memory on GPU Virtual Memory %d, %s memory has " + "been allocated and " + "available memory is only %s.\n\n" + "Please decrease the batch size of your model.\n\n", + place_.device, string::HumanReadableSize(size), place_.device, + string::HumanReadableSize(virtual_mem_alloced_offset_), + string::HumanReadableSize(virtual_mem_size_ - + virtual_mem_alloced_offset_), + place_.device)); + return nullptr; + } + + CUmemGenericAllocationHandle handle; + + paddle::platform::CUDADeviceGuard guard(place_.device); + + // Create physical memory backing allocation. + auto result = + platform::RecordedCuMemCreate(&handle, size, &prop_, 0, place_.device); + + if (result != CUDA_SUCCESS) { + if (result == CUDA_ERROR_OUT_OF_MEMORY) { + size_t actual_avail, actual_total; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemGetInfo(&actual_avail, &actual_total)); + size_t actual_allocated = actual_total - actual_avail; + + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on GPU %d. " + "Cannot allocate %s memory on GPU %d, %s memory has been allocated " + "and " + "available memory is only %s.\n\n" + "Please check whether there is any other process using GPU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" + "2. If no, please decrease the batch size of your model.\n\n", + place_.device, string::HumanReadableSize(size), place_.device, + string::HumanReadableSize(actual_allocated), + string::HumanReadableSize(actual_avail), place_.device)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(result); + } + return nullptr; + } + + // Assign the chunk to the appropriate VA range and release the handle. + // After mapping the memory, it can be referenced by virtual address. + // The allocation will be kept live until it is unmapped. + result = paddle::platform::dynload::cuMemMap(ptr, size, 0, handle, 0); + + if (result != CUDA_SUCCESS) { + platform::RecordedCuMemRelease(handle, size, place_.device); + PADDLE_ENFORCE_CUDA_SUCCESS(result); + return nullptr; + } + + // Apply the access descriptors to the whole VA range. + result = paddle::platform::dynload::cuMemSetAccess( + ptr, size, access_desc_.data(), access_desc_.size()); + + if (result != CUDA_SUCCESS) { + paddle::platform::dynload::cuMemUnmap(ptr, size); + platform::RecordedCuMemRelease(handle, size, place_.device); + PADDLE_ENFORCE_CUDA_SUCCESS(result); + return nullptr; + } + + virtual_2_physical_map_.emplace(ptr, std::make_pair(handle, size)); + + virtual_mem_alloced_offset_ += size; + + return new Allocation(reinterpret_cast(ptr), size, + platform::Place(place_)); +} + +} // namespace allocation +} // namespace memory +} // namespace paddle + +#endif diff --git a/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h new file mode 100644 index 00000000000..c51b56566bb --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_virtual_mem_allocator.h @@ -0,0 +1,62 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#ifdef PADDLE_WITH_CUDA +#include +#include "paddle/fluid/platform/cuda_device_guard.h" +#endif + +#include // NOLINT +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/platform/place.h" + +#if CUDA_VERSION >= 10020 + +namespace paddle { +namespace memory { +namespace allocation { + +// Allocate memory using NVIDIA's virtual memory management technology +class CUDAVirtualMemAllocator : public Allocator { + public: + explicit CUDAVirtualMemAllocator(const platform::CUDAPlace& place); + + bool IsAllocThreadSafe() const override; + + protected: + void FreeImpl(Allocation* allocation) override; + Allocation* AllocateImpl(size_t size) override; + + private: + platform::CUDAPlace place_; + + CUdeviceptr virtual_mem_base_; + size_t virtual_mem_size_; + size_t virtual_mem_alloced_offset_; + size_t granularity_; + + CUmemAllocationProp prop_; + std::vector access_desc_; + + std::map> + virtual_2_physical_map_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle + +#endif diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc new file mode 100644 index 00000000000..5c7e8e2d933 --- /dev/null +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc @@ -0,0 +1,254 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/memory/allocation/aligned_allocator.h" +#include "paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h" + +namespace paddle { +namespace memory { +namespace allocation { + +bool NeedSplit(size_t block_size, size_t alignment, size_t allock_size) { + return block_size > (allock_size * 2) || + (block_size - allock_size) > alignment; +} + +VirtualMemoryAutoGrowthBestFitAllocator:: + VirtualMemoryAutoGrowthBestFitAllocator( + const std::shared_ptr &underlying_allocator, + size_t alignment, const platform::CUDAPlace &place) + : underlying_allocator_( + std::make_shared(underlying_allocator, alignment)), + alignment_(alignment), + place_(place) {} + +Allocation *VirtualMemoryAutoGrowthBestFitAllocator::AllocateImpl(size_t size) { + std::lock_guard guard(spinlock_); + size = AlignedSize(size, alignment_); + auto result = AllocFromFreeBlocks(size); + + if (!result) { + ExtendAndMerge(size); + result = AllocFromFreeBlocks(size); + } + + return result; +} + +void VirtualMemoryAutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) { + std::lock_guard guard(spinlock_); + auto block_it = static_cast(allocation)->block_it_; + TryMergeBlock2Blocks(block_it); + delete allocation; +} + +void VirtualMemoryAutoGrowthBestFitAllocator::TryMergeBlock2Blocks( + std::list::iterator block) { + if (block->ptr_ == all_blocks_.front().ptr_ && + block->ptr_ == all_blocks_.back().ptr_) { + block->is_free_ = true; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } else if (block->ptr_ == all_blocks_.front().ptr_) { + auto next = std::next(block); + if (next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == next->ptr_) { + // merge with next + block->size_ += next->size_; + block->is_free_ = true; + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + all_blocks_.erase(next); + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } else { + block->is_free_ = true; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } + } else if (block->ptr_ == all_blocks_.back().ptr_) { + auto pre = std::prev(block); + if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_) { + // merge with pre + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + pre->size_ += block->size_; + all_blocks_.erase(block); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else { + block->is_free_ = true; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } + } else { + auto pre = std::prev(block); + auto next = std::next(block); + if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == block->ptr_ && + !(next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == + next->ptr_)) { + // merge with pre + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + pre->size_ += block->size_; + all_blocks_.erase(block); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else if (next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == + next->ptr_ && + !(pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == + block->ptr_)) { + // merge with next + block->size_ += next->size_; + block->is_free_ = true; + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + all_blocks_.erase(next); + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } else if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == + block->ptr_ && + next->is_free_ && + reinterpret_cast(block->ptr_) + block->size_ == + next->ptr_) { + // merge with pre and next + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + pre->size_ += (block->size_ + next->size_); + all_blocks_.erase(block); + all_blocks_.erase(next); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else { + block->is_free_ = true; + free_blocks_.emplace(std::make_pair(block->size_, block->ptr_), block); + } + } +} + +void VirtualMemoryAutoGrowthBestFitAllocator::ExtendAndMerge(size_t size) { + void *ptr = nullptr; + + auto allocateptr = underlying_allocator_->Allocate(size); + ptr = allocateptr->ptr(); + size = allocateptr->size(); + allocations_.push_back(std::move(allocateptr)); // hold allocation + + if (all_blocks_.empty()) { + all_blocks_.push_back(Block(ptr, size, true)); + free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); + return; + } + for (auto block_it = all_blocks_.begin(); block_it != all_blocks_.end(); + ++block_it) { + if (block_it->ptr_ > ptr) { + if (block_it == all_blocks_.begin()) { + // insert to front + if (block_it->is_free_ && + reinterpret_cast(ptr) + size == block_it->ptr_) { + // merge with next + free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); + block_it->ptr_ = ptr; + block_it->size_ += size; + free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), + block_it); + } else { + // do not merge + all_blocks_.push_front(Block(ptr, size, true)); + free_blocks_.emplace(std::make_pair(size, ptr), all_blocks_.begin()); + } + } else { + // insert to middle + auto next = block_it; + auto pre = std::prev(block_it); + if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == ptr && + !(next->is_free_ && + reinterpret_cast(ptr) + size == next->ptr_)) { + // merge with pre + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + pre->size_ += size; + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + } else if (next->is_free_ && + reinterpret_cast(ptr) + size == next->ptr_ && + !(pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == + ptr)) { + // merge with next + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + next->ptr_ = ptr; + next->size_ += size; + free_blocks_.emplace(std::make_pair(next->size_, next->ptr_), next); + } else if (pre->is_free_ && + reinterpret_cast(pre->ptr_) + pre->size_ == ptr && + next->is_free_ && + reinterpret_cast(ptr) + size == next->ptr_) { + // merge with pre and next + free_blocks_.erase(std::make_pair(pre->size_, pre->ptr_)); + free_blocks_.erase(std::make_pair(next->size_, next->ptr_)); + pre->size_ += (size + next->size_); + free_blocks_.emplace(std::make_pair(pre->size_, pre->ptr_), pre); + all_blocks_.erase(next); + } else { + // do not merge + auto iter = all_blocks_.insert(next, Block(ptr, size, true)); + free_blocks_.emplace(std::make_pair(size, ptr), iter); + } + } + return; + } + } + + // insert to back + auto block_it = all_blocks_.end(); + block_it--; + if (block_it->is_free_ && + reinterpret_cast(block_it->ptr_) + block_it->size_ == ptr) { + // merge with pre + free_blocks_.erase(std::make_pair(block_it->size_, block_it->ptr_)); + block_it->size_ += size; + free_blocks_.emplace(std::make_pair(block_it->size_, block_it->ptr_), + block_it); + } else { + // do not merge + all_blocks_.push_back(Block(ptr, size, true)); + auto block_it = all_blocks_.end(); + block_it--; + free_blocks_.emplace(std::make_pair(size, ptr), block_it); + } +} + +Allocation *VirtualMemoryAutoGrowthBestFitAllocator::AllocFromFreeBlocks( + size_t size) { + auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); + if (iter != free_blocks_.end()) { + std::list::iterator block_it = iter->second; + free_blocks_.erase(iter); + if (NeedSplit(block_it->size_, alignment_, size)) { + size_t remaining_size = block_it->size_ - size; + auto remaining_free_block = all_blocks_.insert( + block_it, Block(block_it->ptr_, remaining_size, true)); + free_blocks_.emplace(std::make_pair(remaining_size, block_it->ptr_), + remaining_free_block); + block_it->ptr_ = + reinterpret_cast(block_it->ptr_) + remaining_size; + block_it->size_ = size; + } + + block_it->is_free_ = false; + return new BlockAllocation(block_it, place_); + } + + return nullptr; +} + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h new file mode 100644 index 00000000000..5171e5b3cd1 --- /dev/null +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h @@ -0,0 +1,84 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/allocation/spin_lock.h" + +namespace paddle { +namespace memory { +namespace allocation { + +struct Block { + Block(void *ptr, size_t size, bool is_free) + : ptr_(ptr), size_(size), is_free_(is_free) {} + + void *ptr_; + size_t size_; + bool is_free_; +}; + +struct BlockAllocation : public Allocation { + explicit BlockAllocation(const std::list::iterator &it, + platform::Place place) + : Allocation(it->ptr_, it->size_, place), block_it_(it) {} + + std::list::iterator block_it_; +}; + +/** + * Like AutoGrowthBestFitAllocator, VirtualMemoryAutoGrowthBestFitAllocator will + * gradually apply to GPU for video memory as the model uses more video memory. + * However, the difference is that VirtualMemoryAutoGrowthBestFitAllocator uses + * nviaid's virtual memory management technology and obtains the virtual memory + * address. If the video memory applied for twice is continuous, we can combine + * the two video memories later. This combination can greatly reduce + * fragmentation. + */ +class VirtualMemoryAutoGrowthBestFitAllocator : public Allocator { + public: + VirtualMemoryAutoGrowthBestFitAllocator( + const std::shared_ptr &underlying_allocator, size_t alignment, + const platform::CUDAPlace &place); + + bool IsAllocThreadSafe() const override { return true; } + + protected: + Allocation *AllocateImpl(size_t size) override; + + void FreeImpl(Allocation *allocation) override; + + private: + Allocation *AllocFromFreeBlocks(size_t size); + void ExtendAndMerge(size_t size); + void TryMergeBlock2Blocks(std::list::iterator iter); + + std::shared_ptr underlying_allocator_; + size_t alignment_; + + std::map, std::list::iterator> free_blocks_; + std::list all_blocks_; + std::list allocations_; + platform::Place place_; + SpinLock spinlock_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 6e90ccfc51e..b396caf54a4 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -11,8 +11,8 @@ if (WITH_ROCM) endif() # There is no macOS version of NCCL. -# Disable nvrtc and cuda_driver api on MacOS and Windows, and only do a early test on Linux. -if (NOT APPLE AND NOT WIN32) +# Disable nvrtc and cuda_driver api on MacOS, and only do a early test on Linux and Windows. +if (NOT APPLE) list(APPEND CUDA_SRCS nvrtc.cc cuda_driver.cc) if (WITH_NCCL) list(APPEND CUDA_SRCS nccl.cc) diff --git a/paddle/fluid/platform/dynload/cuda_driver.cc b/paddle/fluid/platform/dynload/cuda_driver.cc index 89a29bae7f3..6110e6b6ba9 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.cc +++ b/paddle/fluid/platform/dynload/cuda_driver.cc @@ -23,6 +23,9 @@ void* cuda_dso_handle = nullptr; #define DEFINE_WRAP(__name) DynLoad__##__name __name +#if CUDA_VERSION >= 10020 +CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP); +#endif CUDA_ROUTINE_EACH(DEFINE_WRAP); bool HasCUDADriver() { diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h index 5799b084f5f..b5212c64cd1 100644 --- a/paddle/fluid/platform/dynload/cuda_driver.h +++ b/paddle/fluid/platform/dynload/cuda_driver.h @@ -57,7 +57,23 @@ extern bool HasCUDADriver(); __macro(cuCtxCreate); \ __macro(cuCtxGetCurrent); \ __macro(cuDeviceGetCount); \ - __macro(cuDevicePrimaryCtxGetState) + __macro(cuDevicePrimaryCtxGetState); \ + __macro(cuDeviceGetAttribute); \ + __macro(cuDeviceGet) + +#if CUDA_VERSION >= 10020 +#define CUDA_ROUTINE_EACH_VVM(__macro) \ + __macro(cuMemGetAllocationGranularity); \ + __macro(cuMemAddressReserve); \ + __macro(cuMemCreate); \ + __macro(cuMemMap); \ + __macro(cuMemSetAccess); \ + __macro(cuMemUnmap); \ + __macro(cuMemRelease); \ + __macro(cuMemAddressFree) + +CUDA_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); +#endif CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP); diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 1bfd48b1339..544c1c194d9 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -21,6 +21,10 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cupti_lib_path.h" #include "paddle/fluid/platform/enforce.h" +#if defined(_WIN32) +#include +#endif + DEFINE_string(cudnn_dir, "", "Specify path for loading libcudnn.so. For instance, " "/usr/local/cudnn/lib. If empty [default], dlopen " @@ -414,6 +418,10 @@ void* GetCUDADsoHandle() { return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false); #elif defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false); +#elif defined(_WIN32) + char system32_dir[MAX_PATH]; + GetSystemDirectory(system32_dir, MAX_PATH); + return GetDsoHandleFromSearchPath(system32_dir, "nvcuda.dll"); #else return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so", false); #endif diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a0e2dd5f7e3..bdb901f583e 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -714,6 +714,7 @@ DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN); DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS); DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER); DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT); +DEFINE_EXTERNAL_API_TYPE(CUresult, CUDA_SUCCESS, CU); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL); @@ -728,6 +729,7 @@ inline const char* GetErrorMsgUrl(T status) { details::ExternalApiType<__CUDA_STATUS_TYPE__>::kProtoType; switch (proto_type) { case platform::proto::ApiType::CUDA: + case platform::proto::ApiType::CU: return "https://docs.nvidia.com/cuda/cuda-runtime-api/" "group__CUDART__TYPES.html#group__CUDART__TYPES_" "1g3f51e3575c2178246db0a94a430e0038"; @@ -842,6 +844,7 @@ template std::string GetExternalErrorMsg(cudnnStatus_t); template std::string GetExternalErrorMsg(cublasStatus_t); template std::string GetExternalErrorMsg(cusolverStatus_t); template std::string GetExternalErrorMsg(cufftResult_t); +template std::string GetExternalErrorMsg(CUresult); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) template std::string GetExternalErrorMsg(ncclResult_t); #endif @@ -911,6 +914,15 @@ inline std::string build_nvidia_error_msg(cufftResult_t stat) { return sout.str(); } +/*************** CUresult ERROR ***************/ +inline bool is_error(CUresult stat) { return stat != CUDA_SUCCESS; } + +inline std::string build_nvidia_error_msg(CUresult stat) { + std::ostringstream sout; + sout << "CU error(" << stat << "). " << GetExternalErrorMsg(stat); + return sout.str(); +} + /**************** NCCL ERROR ****************/ #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) inline bool is_error(ncclResult_t nccl_result) { diff --git a/paddle/fluid/platform/external_error.proto b/paddle/fluid/platform/external_error.proto index cbbf803492e..fcbbb416261 100644 --- a/paddle/fluid/platform/external_error.proto +++ b/paddle/fluid/platform/external_error.proto @@ -25,6 +25,7 @@ enum ApiType { CUSOLVER = 4; NCCL = 5; CUFFT = 6; + CU = 7; } message MessageDesc { diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index c624ba94b74..9dc6254234a 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -26,6 +26,11 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cudnn.h" #endif #include "paddle/fluid/memory/malloc.h" +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 +#include "paddle/fluid/platform/dynload/cuda_driver.h" +#endif +#endif #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/lock_guard_ptr.h" #include "paddle/fluid/platform/macros.h" @@ -641,6 +646,30 @@ class RecordedCudaMallocHelper { uint64_t LimitSize() const { return limit_size_; } +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 + CUresult MemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags) { // NOLINT + auto result = + paddle::platform::dynload::cuMemCreate(handle, size, prop, flags); + if (result == CUDA_SUCCESS) { + cur_size_.fetch_add(size); + } + return result; + } + + CUresult MemRelease(CUmemGenericAllocationHandle handle, size_t size) { + auto result = paddle::platform::dynload::cuMemRelease(handle); + if (result == CUDA_SUCCESS) { + cur_size_.fetch_sub(size); + } + return result; + } + +#endif +#endif + private: const int dev_id_; const uint64_t limit_size_; @@ -664,6 +693,22 @@ void RecordedCudaFree(void *p, size_t size, int dev_id) { return RecordedCudaMallocHelper::Instance(dev_id)->Free(p, size); } +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 +CUresult RecordedCuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags, int dev_id) { // NOLINT + return RecordedCudaMallocHelper::Instance(dev_id)->MemCreate(handle, size, + prop, flags); +} + +CUresult RecordedCuMemRelease(CUmemGenericAllocationHandle handle, size_t size, + int dev_id) { + return RecordedCudaMallocHelper::Instance(dev_id)->MemRelease(handle, size); +} +#endif +#endif + bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, size_t *actual_total, int dev_id) { return RecordedCudaMallocHelper::Instance(dev_id)->GetMemInfo( diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 401873dcd77..93e787fcf36 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -131,6 +131,20 @@ gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id); //! CudaFree with recorded info void RecordedCudaFree(void *p, size_t size, int dev_id); +#ifdef PADDLE_WITH_CUDA +#if CUDA_VERSION >= 10020 + +//! cuMemCreate with recorded info +CUresult RecordedCuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags, int dev_id); // NOLINT + +//! cuMemRelease with recorded info +CUresult RecordedCuMemRelease(CUmemGenericAllocationHandle handle, size_t size, + int dev_id); +#endif +#endif + //! Get available and total gpu memory with considering limitation bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, size_t *actual_total, int dev_id); diff --git a/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py index 8b6d37882ba..a73ebd73e49 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py @@ -43,7 +43,7 @@ def _get_softmax_upper(x, fp16=True): class TestSoftmaxMaskFuseOp(OpTest): def setUp(self): self.op_type = "fused_softmax_mask_upper_triangle" - x = np.random.random((1, 1, 32, 32)).astype("float16") + x = np.random.random((1, 4, 32, 32)).astype("float16") self.inputs = {'X': x} rst = _get_softmax_upper(x) self.outputs = {'Out': rst} @@ -60,7 +60,7 @@ class TestSoftmaxMaskFuseOp(OpTest): class TestSoftmaxMaskFuseOp1(OpTest): def setUp(self): self.op_type = "fused_softmax_mask_upper_triangle" - x = np.random.random((1, 1, 32, 32)) + x = np.random.random((1, 4, 32, 32)) self.inputs = {'X': x} rst = _get_softmax_upper(x) self.outputs = {'Out': rst} @@ -90,10 +90,10 @@ class TestDropoutBiasFuseOp2(unittest.TestCase): for dtype in self.dtypes: with fluid.program_guard(fluid.Program(), fluid.Program()): input_x = fluid.data( - name="x", shape=[1, 1, 32, 32], dtype=dtype) + name="x", shape=[1, 4, 32, 32], dtype=dtype) rst = incubate.softmax_mask_fuse_upper_triangle(input_x) - x_in_np = np.random.random((1, 1, 32, 32)).astype(dtype) + x_in_np = np.random.random((1, 4, 32, 32)).astype(dtype) rst_np = _get_softmax_upper(x_in_np, dtype == 'float16') exe = fluid.Executor(fluid.CUDAPlace(0)) @@ -105,7 +105,7 @@ class TestDropoutBiasFuseOp2(unittest.TestCase): def test_dygraph(self): for dtype in self.dtypes: with fluid.dygraph.guard(fluid.CUDAPlace(0)): - x_in_np = np.random.random((1, 1, 32, 32)).astype(dtype) + x_in_np = np.random.random((1, 4, 32, 32)).astype(dtype) rst_np = _get_softmax_upper(x_in_np, dtype == 'float16') input_x = fluid.dygraph.to_variable(x_in_np) -- GitLab