From 42910361d2997e60e0c5c14edd7418f556d97272 Mon Sep 17 00:00:00 2001 From: From00 Date: Tue, 8 Feb 2022 12:05:26 +0800 Subject: [PATCH] Support allocate CUDA managed memory (#39075) * Rough implementation for experiment * Support allocate cuda managed memory * Fix CI error * Modify UT * Check whether support memory oversubscription * Fix ROCM Compile error * Fix ROCM Compile error * Fix UT cuda_managed_memory_test * Set UT timeout to 40 * Add UT OOMExceptionTest * Set UT timeout to 50 --- paddle/fluid/memory/CMakeLists.txt | 17 ++- paddle/fluid/memory/allocation/CMakeLists.txt | 4 +- .../memory/allocation/allocator_facade.cc | 55 +++++-- .../fluid/memory/allocation/cuda_allocator.cc | 12 +- .../allocation/cuda_managed_allocator.cc | 86 +++++++++++ .../allocation/cuda_managed_allocator.h | 41 ++++++ .../fluid/memory/cuda_managed_memory_test.cu | 136 ++++++++++++++++++ paddle/fluid/platform/device/gpu/gpu_info.cc | 35 ++++- paddle/fluid/platform/device/gpu/gpu_info.h | 9 +- paddle/pten/backends/gpu/cuda/cuda_info.cc | 37 +++++ paddle/pten/backends/gpu/gpu_info.h | 4 + paddle/pten/backends/gpu/rocm/rocm_info.cc | 34 +++++ 12 files changed, 447 insertions(+), 23 deletions(-) create mode 100644 paddle/fluid/memory/allocation/cuda_managed_allocator.cc create mode 100644 paddle/fluid/memory/allocation/cuda_managed_allocator.h create mode 100644 paddle/fluid/memory/cuda_managed_memory_test.cu diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 023b40518e..4492615d23 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -20,18 +20,29 @@ if (WITH_GPU) nv_test(stream_safe_cuda_alloc_test SRCS stream_safe_cuda_alloc_test.cu DEPS malloc cuda_graph_with_memory_pool) + nv_test(cuda_managed_memory_test + SRCS cuda_managed_memory_test.cu + DEPS malloc gpu_info place) if(WITH_TESTING AND TEST stream_safe_cuda_alloc_test) set_tests_properties(stream_safe_cuda_alloc_test PROPERTIES - ENVIRONMENT "FLAGS_use_stream_safe_cuda_allocator=true; - FLAGS_allocator_strategy=auto_growth") - endif() + ENVIRONMENT "FLAGS_use_stream_safe_cuda_allocator=true;FLAGS_allocator_strategy=auto_growth") + endif() endif() if (WITH_ROCM) hip_test(malloc_test SRCS malloc_test.cu DEPS device_context malloc) + hip_test(cuda_managed_memory_test + SRCS cuda_managed_memory_test.cu + DEPS malloc gpu_info place) +endif() + +if(WITH_TESTING AND TEST cuda_managed_memory_test) +set_tests_properties(cuda_managed_memory_test PROPERTIES + ENVIRONMENT "FLAGS_use_cuda_managed_memory=true;FLAGS_allocator_strategy=auto_growth" + TIMEOUT 50) endif() if(WITH_GPU AND WITH_TESTING AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON") diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index c0d1934a70..b899ddbcd5 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -15,6 +15,7 @@ endif() if (WITH_GPU) nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) + nv_library(cuda_managed_allocator SRCS cuda_managed_allocator.cc DEPS allocator cuda_device_guard gpu_info) nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) nv_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator) nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) @@ -27,6 +28,7 @@ endif() if (WITH_ROCM) hip_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard) + hip_library(cuda_managed_allocator SRCS cuda_managed_allocator.cc DEPS allocator cuda_device_guard gpu_info) hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator) hip_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator) hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator) @@ -42,7 +44,7 @@ endif() cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator) if (WITH_GPU OR WITH_ROCM) - set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator device_context) + set(AllocatorFacadeDeps gpu_info cuda_allocator cuda_managed_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator device_context) if(CUDA_VERSION GREATER_EQUAL 10.2) list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator) endif() diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 0f725a454c..f2bfaccd1d 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -28,6 +28,7 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include #include "paddle/fluid/memory/allocation/cuda_allocator.h" +#include "paddle/fluid/memory/allocation/cuda_managed_allocator.h" #include "paddle/fluid/memory/allocation/pinned_allocator.h" #include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h" #include "paddle/fluid/memory/allocation/thread_local_allocator.h" @@ -80,6 +81,11 @@ PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false, PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, false, "Enable StreamSafeCUDAAllocator"); +PADDLE_DEFINE_EXPORTED_bool(use_cuda_managed_memory, false, + "Whether to use CUDAManagedAllocator to allocate " + "managed memory, only available for auto_growth " + "strategy"); + DECLARE_string(allocator_strategy); namespace paddle { @@ -436,6 +442,37 @@ class AllocatorFacadePrivate { std::make_shared(platform::CUDAPinnedPlace()); } + void InitNaiveBestFitCUDAAllocator(platform::CUDAPlace p) { + allocators_[p] = std::make_shared(p); + } + + // Create a new CUDAAllocator or CUDAManagedAllocator for the given device + std::shared_ptr CreateCUDAAllocator(platform::CUDAPlace p) { + if (FLAGS_use_cuda_managed_memory) { + PADDLE_ENFORCE_EQ( + strategy_, AllocatorStrategy::kAutoGrowth, + platform::errors::InvalidArgument( + "CUDA managed memory is only implemented for auto_growth " + "strategy, not support %s strategy.\n" + "Please use auto_growth strategy by command `export " + "FLAGS_allocator_strategy=\"auto_growth\"`, or disable managed " + "memory by command `export FLAGS_use_cuda_managed_memory=false`", + FLAGS_allocator_strategy)); + + if (!platform::IsGPUManagedMemorySupported(p.device)) { + PADDLE_THROW(platform::errors::Unavailable( + "Failed to create CUDAManagedAllocator on GPU %d.\n\n" + "You have enabled CUDA managed memory, but the gpu device does not " + "support allocating managed memory.\n" + "If you don't actually need to use managed memory, please disable " + "it with command `export FLAGS_use_cuda_managed_memory=false`.\n" + "Or you must use the gpu device that supports managed memory.")); + } + return std::make_shared(p); + } + return std::make_shared(p); + } + void InitStreamSafeCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) { PADDLE_ENFORCE_EQ( strategy_, AllocatorStrategy::kAutoGrowth, @@ -452,13 +489,9 @@ class AllocatorFacadePrivate { } } - void InitNaiveBestFitCUDAAllocator(platform::CUDAPlace p) { - allocators_[p] = std::make_shared(p); - } - void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) { #if defined(PADDLE_WITH_HIP) - auto cuda_allocator = std::make_shared(p); + auto cuda_allocator = CreateCUDAAllocator(p); cuda_allocators_[p][stream] = std::make_shared( cuda_allocator, platform::GpuMinChunkSize(), 0, allow_free_idle_chunk_); #endif @@ -485,14 +518,14 @@ class AllocatorFacadePrivate { std::make_shared( cuda_allocator, platform::GpuMinChunkSize(), p); } else { - auto cuda_allocator = std::make_shared(p); + auto cuda_allocator = CreateCUDAAllocator(p); cuda_allocators_[p][stream] = std::make_shared( cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk_); } #else - auto cuda_allocator = std::make_shared(p); + auto cuda_allocator = CreateCUDAAllocator(p); auto alignment = platform::GpuMinChunkSize(); bool need_addr_align = true; // NOTE: sometimes, since cuda runtime can not be forked, calling any cuda @@ -535,7 +568,7 @@ class AllocatorFacadePrivate { void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, bool allow_free_idle_chunk) { #if defined(PADDLE_WITH_HIP) - auto cuda_allocator = std::make_shared(p); + auto cuda_allocator = CreateCUDAAllocator(p); allocators_[p] = std::make_shared( cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); #endif @@ -562,13 +595,13 @@ class AllocatorFacadePrivate { std::make_shared( cuda_allocator, platform::GpuMinChunkSize(), p); } else { - auto cuda_allocator = std::make_shared(p); + auto cuda_allocator = CreateCUDAAllocator(p); allocators_[p] = std::make_shared( cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); } #else - auto cuda_allocator = std::make_shared(p); + auto cuda_allocator = CreateCUDAAllocator(p); auto alignment = platform::GpuMinChunkSize(); bool need_addr_align = true; // NOTE: sometimes, since cuda runtime can not be forked, calling any cuda @@ -690,7 +723,7 @@ class AllocatorFacadePrivate { int device_count = platform::GetGPUDeviceCount(); for (int i = 0; i < device_count; ++i) { platform::CUDAPlace p(i); - system_allocators_[p] = std::make_shared(p); + system_allocators_[p] = CreateCUDAAllocator(p); } #endif #ifdef PADDLE_WITH_MLU diff --git a/paddle/fluid/memory/allocation/cuda_allocator.cc b/paddle/fluid/memory/allocation/cuda_allocator.cc index 6000e636dd..99a8efe05b 100644 --- a/paddle/fluid/memory/allocation/cuda_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_allocator.cc @@ -67,16 +67,24 @@ pten::Allocation* CUDAAllocator::AllocateImpl(size_t size) { limit_size, limit_size); } + std::string managed_memory_msg; + if (platform::IsGPUManagedMemoryOversubscriptionSupported(place_.device)) { + managed_memory_msg = string::Sprintf( + "If the above ways do not solve the out of memory problem, you can try " + "to use CUDA managed memory. The command is `export " + "FLAGS_use_cuda_managed_memory=false`."); + } + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( "\n\nOut of memory error on GPU %d. " "Cannot allocate %s memory on GPU %d, %s memory has been allocated and " "available memory is only %s.\n\n" "Please check whether there is any other process using GPU %d.\n" "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" - "2. If no, please decrease the batch size of your model. %s\n\n", + "2. If no, please decrease the batch size of your model. %s\n%s\n", place_.device, string::HumanReadableSize(size), place_.device, string::HumanReadableSize(allocated), string::HumanReadableSize(avail), - place_.device, err_msg)); + place_.device, err_msg, managed_memory_msg)); } } // namespace allocation diff --git a/paddle/fluid/memory/allocation/cuda_managed_allocator.cc b/paddle/fluid/memory/allocation/cuda_managed_allocator.cc new file mode 100644 index 0000000000..000b5d2e37 --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_managed_allocator.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/memory/allocation/cuda_managed_allocator.h" + +#ifdef PADDLE_WITH_CUDA +#include +#include +#endif + +#ifdef PADDLE_WITH_HIP +#include +#endif + +#include +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace memory { +namespace allocation { +bool CUDAManagedAllocator::IsAllocThreadSafe() const { return true; } + +void CUDAManagedAllocator::FreeImpl(pten::Allocation* allocation) { + PADDLE_ENFORCE_EQ( + allocation->place(), place_, + platform::errors::PermissionDenied( + "GPU memory is freed in incorrect device. This may be a bug")); + platform::RecordedGpuFree(allocation->ptr(), allocation->size(), + place_.device); + delete allocation; +} + +pten::Allocation* CUDAManagedAllocator::AllocateImpl(size_t size) { + std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); }); + + int dev_id = place_.device; + void* ptr; + auto result = platform::RecordedGpuMalloc(&ptr, size, dev_id, + /* malloc_managed_memory = */ true); + if (LIKELY(result == gpuSuccess)) { + return new Allocation(ptr, size, platform::Place(place_)); + } + + uint64_t limit_size = platform::RecordedGpuLimitSize(dev_id); + uint64_t malloc_size = platform::RecordedGpuMallocSize(dev_id); + bool is_limited = + platform::IsGpuMallocRecorded(dev_id) && malloc_size + size > limit_size; + + std::string err_msg; + if (UNLIKELY(is_limited)) { + int64_t limit_size_mb = limit_size >> 20; + err_msg = string::Sprintf( + "Or set environment variable `FLAGS_gpu_memory_limit_mb` to a larger " + "value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the maximum " + "GPU memory usage is limited to %d MB.\n" + " The command is `export FLAGS_gpu_memory_limit_mb=xxx`.", + limit_size_mb, limit_size_mb); + } + + PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted( + "\n\nOut of memory error on GPU %d. " + "Cannot allocate %s CUDA managed memory on GPU %d, %s memory has been " + "allocated.\n\n" + "Please check whether there is any other process using GPU %d.\n" + "1. If yes, please stop them, or start PaddlePaddle on another GPU.\n" + "2. If no, please decrease the batch size of your model. %s\n\n", + dev_id, string::HumanReadableSize(size), dev_id, + string::HumanReadableSize(malloc_size), dev_id, err_msg)); +} + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/allocation/cuda_managed_allocator.h b/paddle/fluid/memory/allocation/cuda_managed_allocator.h new file mode 100644 index 0000000000..16279eea6a --- /dev/null +++ b/paddle/fluid/memory/allocation/cuda_managed_allocator.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class CUDAManagedAllocator : public Allocator { + public: + explicit CUDAManagedAllocator(const platform::CUDAPlace& place) + : place_(place) {} + + bool IsAllocThreadSafe() const override; + + protected: + void FreeImpl(pten::Allocation* allocation) override; + pten::Allocation* AllocateImpl(size_t size) override; + + private: + platform::CUDAPlace place_; + std::once_flag once_flag_; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/memory/cuda_managed_memory_test.cu b/paddle/fluid/memory/cuda_managed_memory_test.cu new file mode 100644 index 0000000000..4243c5fa90 --- /dev/null +++ b/paddle/fluid/memory/cuda_managed_memory_test.cu @@ -0,0 +1,136 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUDA +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#endif + +#include "gtest/gtest.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace memory { + +__global__ void write_kernel(int* data, uint64_t n, uint64_t step) { + int thread_num = gridDim.x * blockDim.x; + int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + for (uint64_t i = thread_id; i * step < n; i += thread_num) { + *(data + i * step) = 1; + } +} + +__global__ void sum_kernel(int* data, uint64_t n, uint64_t step, int* sum) { + int thread_num = gridDim.x * blockDim.x; + int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + for (uint64_t i = thread_id; i * step < n; i += thread_num) { + atomicAdd(sum, *(data + i * step)); + } +} + +TEST(ManagedMemoryTest, H2DTest) { + if (!platform::IsGPUManagedMemorySupported(0)) { + return; + } + + uint64_t n_data = 1024; + uint64_t step = 1; + allocation::AllocationPtr allocation = + Alloc(platform::CUDAPlace(0), n_data * sizeof(int)); + int* data = static_cast(allocation->ptr()); + + memset(data, 0, n_data * sizeof(int)); // located on host memory + write_kernel<<<1, 1024>>>(data, n_data, step); // trans to device memory + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + + int sum = 0; + for (uint64_t i = 0; i < n_data; ++i) { + sum += *(data + i); + } + EXPECT_EQ(sum, n_data / step); + allocation = nullptr; +} + +TEST(ManagedMemoryTest, D2HTest) { + if (!platform::IsGPUManagedMemorySupported(0)) { + return; + } + + uint64_t n_data = 1024; + uint64_t step = 1; + AllocationPtr allocation = + Alloc(platform::CUDAPlace(0), n_data * sizeof(int)); + int* data = static_cast(allocation->ptr()); + + write_kernel<<<1, 1024>>>(data, n_data, step); // located on device memory + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + + memset(data, 0, n_data * sizeof(int)); // trans to host memory + + int sum = 0; + for (uint64_t i = 0; i < n_data; ++i) { + sum += *(data + i); + } + EXPECT_EQ(sum, 0); +} + +TEST(ManagedMemoryTest, OversubscribeGPUMemoryTest) { + if (!platform::IsGPUManagedMemoryOversubscriptionSupported(0)) { + return; + } + + uint64_t available_mem = platform::GpuAvailableMemToAlloc(); + uint64_t n_data = available_mem * 2 / sizeof(int) + + 1; // requires more than 2 * available_mem bytes + uint64_t step = 1024; + AllocationPtr data_allocation = + Alloc(platform::CUDAPlace(0), n_data * sizeof(int)); + AllocationPtr sum_allocation = Alloc(platform::CUDAPlace(0), sizeof(int)); + int* data = static_cast(data_allocation->ptr()); + int* sum = static_cast(sum_allocation->ptr()); + (*sum) = 0; + + write_kernel<<<5120, 1024>>>(data, n_data, step); + sum_kernel<<<5120, 1024>>>(data, n_data, step, sum); + +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#endif + + EXPECT_EQ(*sum, (n_data + step - 1) / step); +} + +TEST(ManagedMemoryTest, OOMExceptionTest) { + EXPECT_THROW(Alloc(platform::CUDAPlace(0), size_t(1) << 60), + memory::allocation::BadAlloc); +} + +} // namespace memory +} // namespace paddle diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc index 59fb26e696..3957fe3c52 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.cc +++ b/paddle/fluid/platform/device/gpu/gpu_info.cc @@ -167,18 +167,28 @@ class RecordedGpuMallocHelper { * or cudaSuccess would be returned, and the cudaGetLastError() flag * would be clear. */ - gpuError_t Malloc(void **ptr, size_t size) { + gpuError_t Malloc(void **ptr, size_t size, + bool malloc_managed_memory = false) { LockGuardPtr lock(mtx_); if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) { return gpuErrorOutOfMemory; } CUDADeviceGuard guard(dev_id_); + gpuError_t result; #ifdef PADDLE_WITH_HIP - auto result = hipMalloc(ptr, size); + if (UNLIKELY(malloc_managed_memory)) { + result = hipMallocManaged(ptr, size); + } else { + result = hipMalloc(ptr, size); + } #else CUDAGraphCaptureModeGuard capture_mode_guard; - auto result = cudaMalloc(ptr, size); + if (UNLIKELY(malloc_managed_memory)) { + result = cudaMallocManaged(ptr, size); + } else { + result = cudaMalloc(ptr, size); + } #endif if (result == gpuSuccess) { cur_size_.fetch_add(size); @@ -318,8 +328,10 @@ std::once_flag RecordedGpuMallocHelper::once_flag_; std::vector> RecordedGpuMallocHelper::instances_; -gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id) { - return RecordedGpuMallocHelper::Instance(dev_id)->Malloc(ptr, size); +gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id, + bool malloc_managed_memory) { + return RecordedGpuMallocHelper::Instance(dev_id)->Malloc( + ptr, size, malloc_managed_memory); } void RecordedGpuFree(void *p, size_t size, int dev_id) { @@ -352,6 +364,10 @@ uint64_t RecordedGpuMallocSize(int dev_id) { return RecordedGpuMallocHelper::Instance(dev_id)->RecordedSize(); } +uint64_t RecordedGpuLimitSize(int dev_id) { + return RecordedGpuMallocHelper::Instance(dev_id)->LimitSize(); +} + bool IsGpuMallocRecorded(int dev_id) { return RecordedGpuMallocHelper::Instance(dev_id)->NeedRecord(); } @@ -363,6 +379,15 @@ void EmptyCache(void) { } } +bool IsGPUManagedMemorySupported(int dev_id) { + return pten::backends::gpu::IsGPUManagedMemorySupported(dev_id); +} + +bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) { + return pten::backends::gpu::IsGPUManagedMemoryOversubscriptionSupported( + dev_id); +} + void *GetGpuBasePtr(void *ptr, int dev_id) { return RecordedGpuMallocHelper::Instance(dev_id)->GetBasePtr(ptr); } diff --git a/paddle/fluid/platform/device/gpu/gpu_info.h b/paddle/fluid/platform/device/gpu/gpu_info.h index f6fb2ad8ce..94b47cca94 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.h +++ b/paddle/fluid/platform/device/gpu/gpu_info.h @@ -114,7 +114,8 @@ void GpuDestroyStream(gpuStream_t stream); void GpuDeviceSync(); //! CudaMalloc with recorded info -gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id); +gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id, + bool malloc_managed_memory = false); //! CudaFree with recorded info void RecordedGpuFree(void *p, size_t size, int dev_id); @@ -141,11 +142,17 @@ bool RecordedGpuMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail, //! Get recorded cudaMalloc size. If record is disabled, return 0. uint64_t RecordedGpuMallocSize(int dev_id); +uint64_t RecordedGpuLimitSize(int dev_id); + bool IsGpuMallocRecorded(int dev_id); //! Empty idle cached memory held by the allocator. void EmptyCache(void); +bool IsGPUManagedMemorySupported(int dev_id); + +bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id); + //! Get the primitive pointer return from cudaMalloc, just implemented with //! testing, do not use for release void *GetGpuBasePtr(void *ptr, int dev_id); diff --git a/paddle/pten/backends/gpu/cuda/cuda_info.cc b/paddle/pten/backends/gpu/cuda/cuda_info.cc index 55766facac..de28f7f344 100644 --- a/paddle/pten/backends/gpu/cuda/cuda_info.cc +++ b/paddle/pten/backends/gpu/cuda/cuda_info.cc @@ -290,6 +290,43 @@ void GpuDeviceSync() { PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); } gpuError_t GpuGetLastError() { return cudaGetLastError(); } +// See +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-requirements +// for more detail about managed memory requirements +bool IsGPUManagedMemorySupported(int dev_id) { + PADDLE_ENFORCE_LT(dev_id, + GetGPUDeviceCount(), + paddle::platform::errors::InvalidArgument( + "Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); +#if defined(__linux__) || defined(_WIN32) + int ManagedMemoryAttr; + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &ManagedMemoryAttr, cudaDevAttrManagedMemory, dev_id)); + return ManagedMemoryAttr != 0; +#else + return false; +#endif +} + +bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) { + PADDLE_ENFORCE_LT(dev_id, + GetGPUDeviceCount(), + paddle::platform::errors::InvalidArgument( + "Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); +#ifdef __linux__ + return IsGPUManagedMemorySupported(dev_id) && + GetGPUComputeCapability(dev_id) >= 60; +#else + return false; +#endif +} + } // namespace gpu } // namespace backends } // namespace pten diff --git a/paddle/pten/backends/gpu/gpu_info.h b/paddle/pten/backends/gpu/gpu_info.h index 59add6166d..5f94d1d836 100644 --- a/paddle/pten/backends/gpu/gpu_info.h +++ b/paddle/pten/backends/gpu/gpu_info.h @@ -104,6 +104,10 @@ void GpuDeviceSync(); gpuError_t GpuGetLastError(); +bool IsGPUManagedMemorySupported(int dev_id); + +bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id); + class GPUDeviceGuard { public: explicit inline GPUDeviceGuard(int dev_id) { diff --git a/paddle/pten/backends/gpu/rocm/rocm_info.cc b/paddle/pten/backends/gpu/rocm/rocm_info.cc index 095acdf076..390d5c4ad2 100644 --- a/paddle/pten/backends/gpu/rocm/rocm_info.cc +++ b/paddle/pten/backends/gpu/rocm/rocm_info.cc @@ -292,6 +292,40 @@ void GpuDeviceSync() { PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); } gpuError_t GpuGetLastError() { return hipGetLastError(); } +bool IsGPUManagedMemorySupported(int dev_id) { + PADDLE_ENFORCE_LT(dev_id, + GetGPUDeviceCount(), + paddle::platform::errors::InvalidArgument( + "Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); +#if defined(__linux__) || defined(_WIN32) + int ManagedMemoryAttr; + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute( + &ManagedMemoryAttr, hipDeviceAttributeManagedMemory, dev_id)); + return ManagedMemoryAttr != 0; +#else + return false; +#endif +} + +bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) { + PADDLE_ENFORCE_LT(dev_id, + GetGPUDeviceCount(), + paddle::platform::errors::InvalidArgument( + "Device id must be less than GPU count, " + "but received id is: %d. GPU count is: %d.", + dev_id, + GetGPUDeviceCount())); +#ifdef __linux__ + return IsGPUManagedMemorySupported(dev_id) && + GetGPUComputeCapability(dev_id) >= 60; +#else + return false; +#endif +} + } // namespace gpu } // namespace backends } // namespace pten -- GitLab