未验证 提交 42910361 编写于 作者: F From00 提交者: GitHub

Support allocate CUDA managed memory (#39075)

* Rough implementation for experiment

* Support allocate cuda managed memory

* Fix CI error

* Modify UT

* Check whether support memory oversubscription

* Fix ROCM Compile error

* Fix ROCM Compile error

* Fix UT cuda_managed_memory_test

* Set UT timeout to 40

* Add UT OOMExceptionTest

* Set UT timeout to 50
上级 4d7ad277
......@@ -20,18 +20,29 @@ if (WITH_GPU)
nv_test(stream_safe_cuda_alloc_test
SRCS stream_safe_cuda_alloc_test.cu
DEPS malloc cuda_graph_with_memory_pool)
nv_test(cuda_managed_memory_test
SRCS cuda_managed_memory_test.cu
DEPS malloc gpu_info place)
if(WITH_TESTING AND TEST stream_safe_cuda_alloc_test)
set_tests_properties(stream_safe_cuda_alloc_test PROPERTIES
ENVIRONMENT "FLAGS_use_stream_safe_cuda_allocator=true;
FLAGS_allocator_strategy=auto_growth")
endif()
ENVIRONMENT "FLAGS_use_stream_safe_cuda_allocator=true;FLAGS_allocator_strategy=auto_growth")
endif()
endif()
if (WITH_ROCM)
hip_test(malloc_test
SRCS malloc_test.cu
DEPS device_context malloc)
hip_test(cuda_managed_memory_test
SRCS cuda_managed_memory_test.cu
DEPS malloc gpu_info place)
endif()
if(WITH_TESTING AND TEST cuda_managed_memory_test)
set_tests_properties(cuda_managed_memory_test PROPERTIES
ENVIRONMENT "FLAGS_use_cuda_managed_memory=true;FLAGS_allocator_strategy=auto_growth"
TIMEOUT 50)
endif()
if(WITH_GPU AND WITH_TESTING AND NOT "$ENV{CI_SKIP_CPP_TEST}" STREQUAL "ON")
......
......@@ -15,6 +15,7 @@ endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(cuda_managed_allocator SRCS cuda_managed_allocator.cc DEPS allocator cuda_device_guard gpu_info)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
nv_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator)
nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
......@@ -27,6 +28,7 @@ endif()
if (WITH_ROCM)
hip_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
hip_library(cuda_managed_allocator SRCS cuda_managed_allocator.cc DEPS allocator cuda_device_guard gpu_info)
hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
hip_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator)
hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
......@@ -42,7 +44,7 @@ endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
if (WITH_GPU OR WITH_ROCM)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator device_context)
set(AllocatorFacadeDeps gpu_info cuda_allocator cuda_managed_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator device_context)
if(CUDA_VERSION GREATER_EQUAL 10.2)
list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator)
endif()
......
......@@ -28,6 +28,7 @@
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include <shared_mutex>
#include "paddle/fluid/memory/allocation/cuda_allocator.h"
#include "paddle/fluid/memory/allocation/cuda_managed_allocator.h"
#include "paddle/fluid/memory/allocation/pinned_allocator.h"
#include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h"
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
......@@ -80,6 +81,11 @@ PADDLE_DEFINE_EXPORTED_bool(use_virtual_memory_auto_growth, false,
PADDLE_DEFINE_EXPORTED_bool(use_stream_safe_cuda_allocator, false,
"Enable StreamSafeCUDAAllocator");
PADDLE_DEFINE_EXPORTED_bool(use_cuda_managed_memory, false,
"Whether to use CUDAManagedAllocator to allocate "
"managed memory, only available for auto_growth "
"strategy");
DECLARE_string(allocator_strategy);
namespace paddle {
......@@ -436,6 +442,37 @@ class AllocatorFacadePrivate {
std::make_shared<NaiveBestFitAllocator>(platform::CUDAPinnedPlace());
}
void InitNaiveBestFitCUDAAllocator(platform::CUDAPlace p) {
allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
}
// Create a new CUDAAllocator or CUDAManagedAllocator for the given device
std::shared_ptr<Allocator> CreateCUDAAllocator(platform::CUDAPlace p) {
if (FLAGS_use_cuda_managed_memory) {
PADDLE_ENFORCE_EQ(
strategy_, AllocatorStrategy::kAutoGrowth,
platform::errors::InvalidArgument(
"CUDA managed memory is only implemented for auto_growth "
"strategy, not support %s strategy.\n"
"Please use auto_growth strategy by command `export "
"FLAGS_allocator_strategy=\"auto_growth\"`, or disable managed "
"memory by command `export FLAGS_use_cuda_managed_memory=false`",
FLAGS_allocator_strategy));
if (!platform::IsGPUManagedMemorySupported(p.device)) {
PADDLE_THROW(platform::errors::Unavailable(
"Failed to create CUDAManagedAllocator on GPU %d.\n\n"
"You have enabled CUDA managed memory, but the gpu device does not "
"support allocating managed memory.\n"
"If you don't actually need to use managed memory, please disable "
"it with command `export FLAGS_use_cuda_managed_memory=false`.\n"
"Or you must use the gpu device that supports managed memory."));
}
return std::make_shared<CUDAManagedAllocator>(p);
}
return std::make_shared<CUDAAllocator>(p);
}
void InitStreamSafeCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) {
PADDLE_ENFORCE_EQ(
strategy_, AllocatorStrategy::kAutoGrowth,
......@@ -452,13 +489,9 @@ class AllocatorFacadePrivate {
}
}
void InitNaiveBestFitCUDAAllocator(platform::CUDAPlace p) {
allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
}
void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, gpuStream_t stream) {
#if defined(PADDLE_WITH_HIP)
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
auto cuda_allocator = CreateCUDAAllocator(p);
cuda_allocators_[p][stream] = std::make_shared<AutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(), 0, allow_free_idle_chunk_);
#endif
......@@ -485,14 +518,14 @@ class AllocatorFacadePrivate {
std::make_shared<VirtualMemoryAutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(), p);
} else {
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
auto cuda_allocator = CreateCUDAAllocator(p);
cuda_allocators_[p][stream] =
std::make_shared<AutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(),
allow_free_idle_chunk_);
}
#else
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
auto cuda_allocator = CreateCUDAAllocator(p);
auto alignment = platform::GpuMinChunkSize();
bool need_addr_align = true;
// NOTE: sometimes, since cuda runtime can not be forked, calling any cuda
......@@ -535,7 +568,7 @@ class AllocatorFacadePrivate {
void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p,
bool allow_free_idle_chunk) {
#if defined(PADDLE_WITH_HIP)
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
auto cuda_allocator = CreateCUDAAllocator(p);
allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk);
#endif
......@@ -562,13 +595,13 @@ class AllocatorFacadePrivate {
std::make_shared<VirtualMemoryAutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(), p);
} else {
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
auto cuda_allocator = CreateCUDAAllocator(p);
allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk);
}
#else
auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
auto cuda_allocator = CreateCUDAAllocator(p);
auto alignment = platform::GpuMinChunkSize();
bool need_addr_align = true;
// NOTE: sometimes, since cuda runtime can not be forked, calling any cuda
......@@ -690,7 +723,7 @@ class AllocatorFacadePrivate {
int device_count = platform::GetGPUDeviceCount();
for (int i = 0; i < device_count; ++i) {
platform::CUDAPlace p(i);
system_allocators_[p] = std::make_shared<CUDAAllocator>(p);
system_allocators_[p] = CreateCUDAAllocator(p);
}
#endif
#ifdef PADDLE_WITH_MLU
......
......@@ -67,16 +67,24 @@ pten::Allocation* CUDAAllocator::AllocateImpl(size_t size) {
limit_size, limit_size);
}
std::string managed_memory_msg;
if (platform::IsGPUManagedMemoryOversubscriptionSupported(place_.device)) {
managed_memory_msg = string::Sprintf(
"If the above ways do not solve the out of memory problem, you can try "
"to use CUDA managed memory. The command is `export "
"FLAGS_use_cuda_managed_memory=false`.");
}
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
"\n\nOut of memory error on GPU %d. "
"Cannot allocate %s memory on GPU %d, %s memory has been allocated and "
"available memory is only %s.\n\n"
"Please check whether there is any other process using GPU %d.\n"
"1. If yes, please stop them, or start PaddlePaddle on another GPU.\n"
"2. If no, please decrease the batch size of your model. %s\n\n",
"2. If no, please decrease the batch size of your model. %s\n%s\n",
place_.device, string::HumanReadableSize(size), place_.device,
string::HumanReadableSize(allocated), string::HumanReadableSize(avail),
place_.device, err_msg));
place_.device, err_msg, managed_memory_msg));
}
} // namespace allocation
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_managed_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <string>
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
bool CUDAManagedAllocator::IsAllocThreadSafe() const { return true; }
void CUDAManagedAllocator::FreeImpl(pten::Allocation* allocation) {
PADDLE_ENFORCE_EQ(
allocation->place(), place_,
platform::errors::PermissionDenied(
"GPU memory is freed in incorrect device. This may be a bug"));
platform::RecordedGpuFree(allocation->ptr(), allocation->size(),
place_.device);
delete allocation;
}
pten::Allocation* CUDAManagedAllocator::AllocateImpl(size_t size) {
std::call_once(once_flag_, [this] { platform::SetDeviceId(place_.device); });
int dev_id = place_.device;
void* ptr;
auto result = platform::RecordedGpuMalloc(&ptr, size, dev_id,
/* malloc_managed_memory = */ true);
if (LIKELY(result == gpuSuccess)) {
return new Allocation(ptr, size, platform::Place(place_));
}
uint64_t limit_size = platform::RecordedGpuLimitSize(dev_id);
uint64_t malloc_size = platform::RecordedGpuMallocSize(dev_id);
bool is_limited =
platform::IsGpuMallocRecorded(dev_id) && malloc_size + size > limit_size;
std::string err_msg;
if (UNLIKELY(is_limited)) {
int64_t limit_size_mb = limit_size >> 20;
err_msg = string::Sprintf(
"Or set environment variable `FLAGS_gpu_memory_limit_mb` to a larger "
"value. Currently `FLAGS_gpu_memory_limit_mb` is %d, so the maximum "
"GPU memory usage is limited to %d MB.\n"
" The command is `export FLAGS_gpu_memory_limit_mb=xxx`.",
limit_size_mb, limit_size_mb);
}
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
"\n\nOut of memory error on GPU %d. "
"Cannot allocate %s CUDA managed memory on GPU %d, %s memory has been "
"allocated.\n\n"
"Please check whether there is any other process using GPU %d.\n"
"1. If yes, please stop them, or start PaddlePaddle on another GPU.\n"
"2. If no, please decrease the batch size of your model. %s\n\n",
dev_id, string::HumanReadableSize(size), dev_id,
string::HumanReadableSize(malloc_size), dev_id, err_msg));
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDAManagedAllocator : public Allocator {
public:
explicit CUDAManagedAllocator(const platform::CUDAPlace& place)
: place_(place) {}
bool IsAllocThreadSafe() const override;
protected:
void FreeImpl(pten::Allocation* allocation) override;
pten::Allocation* AllocateImpl(size_t size) override;
private:
platform::CUDAPlace place_;
std::once_flag once_flag_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
__global__ void write_kernel(int* data, uint64_t n, uint64_t step) {
int thread_num = gridDim.x * blockDim.x;
int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
for (uint64_t i = thread_id; i * step < n; i += thread_num) {
*(data + i * step) = 1;
}
}
__global__ void sum_kernel(int* data, uint64_t n, uint64_t step, int* sum) {
int thread_num = gridDim.x * blockDim.x;
int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
for (uint64_t i = thread_id; i * step < n; i += thread_num) {
atomicAdd(sum, *(data + i * step));
}
}
TEST(ManagedMemoryTest, H2DTest) {
if (!platform::IsGPUManagedMemorySupported(0)) {
return;
}
uint64_t n_data = 1024;
uint64_t step = 1;
allocation::AllocationPtr allocation =
Alloc(platform::CUDAPlace(0), n_data * sizeof(int));
int* data = static_cast<int*>(allocation->ptr());
memset(data, 0, n_data * sizeof(int)); // located on host memory
write_kernel<<<1, 1024>>>(data, n_data, step); // trans to device memory
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
int sum = 0;
for (uint64_t i = 0; i < n_data; ++i) {
sum += *(data + i);
}
EXPECT_EQ(sum, n_data / step);
allocation = nullptr;
}
TEST(ManagedMemoryTest, D2HTest) {
if (!platform::IsGPUManagedMemorySupported(0)) {
return;
}
uint64_t n_data = 1024;
uint64_t step = 1;
AllocationPtr allocation =
Alloc(platform::CUDAPlace(0), n_data * sizeof(int));
int* data = static_cast<int*>(allocation->ptr());
write_kernel<<<1, 1024>>>(data, n_data, step); // located on device memory
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
memset(data, 0, n_data * sizeof(int)); // trans to host memory
int sum = 0;
for (uint64_t i = 0; i < n_data; ++i) {
sum += *(data + i);
}
EXPECT_EQ(sum, 0);
}
TEST(ManagedMemoryTest, OversubscribeGPUMemoryTest) {
if (!platform::IsGPUManagedMemoryOversubscriptionSupported(0)) {
return;
}
uint64_t available_mem = platform::GpuAvailableMemToAlloc();
uint64_t n_data = available_mem * 2 / sizeof(int) +
1; // requires more than 2 * available_mem bytes
uint64_t step = 1024;
AllocationPtr data_allocation =
Alloc(platform::CUDAPlace(0), n_data * sizeof(int));
AllocationPtr sum_allocation = Alloc(platform::CUDAPlace(0), sizeof(int));
int* data = static_cast<int*>(data_allocation->ptr());
int* sum = static_cast<int*>(sum_allocation->ptr());
(*sum) = 0;
write_kernel<<<5120, 1024>>>(data, n_data, step);
sum_kernel<<<5120, 1024>>>(data, n_data, step, sum);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
EXPECT_EQ(*sum, (n_data + step - 1) / step);
}
TEST(ManagedMemoryTest, OOMExceptionTest) {
EXPECT_THROW(Alloc(platform::CUDAPlace(0), size_t(1) << 60),
memory::allocation::BadAlloc);
}
} // namespace memory
} // namespace paddle
......@@ -167,18 +167,28 @@ class RecordedGpuMallocHelper {
* or cudaSuccess would be returned, and the cudaGetLastError() flag
* would be clear.
*/
gpuError_t Malloc(void **ptr, size_t size) {
gpuError_t Malloc(void **ptr, size_t size,
bool malloc_managed_memory = false) {
LockGuardPtr<std::mutex> lock(mtx_);
if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) {
return gpuErrorOutOfMemory;
}
CUDADeviceGuard guard(dev_id_);
gpuError_t result;
#ifdef PADDLE_WITH_HIP
auto result = hipMalloc(ptr, size);
if (UNLIKELY(malloc_managed_memory)) {
result = hipMallocManaged(ptr, size);
} else {
result = hipMalloc(ptr, size);
}
#else
CUDAGraphCaptureModeGuard capture_mode_guard;
auto result = cudaMalloc(ptr, size);
if (UNLIKELY(malloc_managed_memory)) {
result = cudaMallocManaged(ptr, size);
} else {
result = cudaMalloc(ptr, size);
}
#endif
if (result == gpuSuccess) {
cur_size_.fetch_add(size);
......@@ -318,8 +328,10 @@ std::once_flag RecordedGpuMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedGpuMallocHelper>>
RecordedGpuMallocHelper::instances_;
gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id) {
return RecordedGpuMallocHelper::Instance(dev_id)->Malloc(ptr, size);
gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id,
bool malloc_managed_memory) {
return RecordedGpuMallocHelper::Instance(dev_id)->Malloc(
ptr, size, malloc_managed_memory);
}
void RecordedGpuFree(void *p, size_t size, int dev_id) {
......@@ -352,6 +364,10 @@ uint64_t RecordedGpuMallocSize(int dev_id) {
return RecordedGpuMallocHelper::Instance(dev_id)->RecordedSize();
}
uint64_t RecordedGpuLimitSize(int dev_id) {
return RecordedGpuMallocHelper::Instance(dev_id)->LimitSize();
}
bool IsGpuMallocRecorded(int dev_id) {
return RecordedGpuMallocHelper::Instance(dev_id)->NeedRecord();
}
......@@ -363,6 +379,15 @@ void EmptyCache(void) {
}
}
bool IsGPUManagedMemorySupported(int dev_id) {
return pten::backends::gpu::IsGPUManagedMemorySupported(dev_id);
}
bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) {
return pten::backends::gpu::IsGPUManagedMemoryOversubscriptionSupported(
dev_id);
}
void *GetGpuBasePtr(void *ptr, int dev_id) {
return RecordedGpuMallocHelper::Instance(dev_id)->GetBasePtr(ptr);
}
......
......@@ -114,7 +114,8 @@ void GpuDestroyStream(gpuStream_t stream);
void GpuDeviceSync();
//! CudaMalloc with recorded info
gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id);
gpuError_t RecordedGpuMalloc(void **ptr, size_t size, int dev_id,
bool malloc_managed_memory = false);
//! CudaFree with recorded info
void RecordedGpuFree(void *p, size_t size, int dev_id);
......@@ -141,11 +142,17 @@ bool RecordedGpuMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
//! Get recorded cudaMalloc size. If record is disabled, return 0.
uint64_t RecordedGpuMallocSize(int dev_id);
uint64_t RecordedGpuLimitSize(int dev_id);
bool IsGpuMallocRecorded(int dev_id);
//! Empty idle cached memory held by the allocator.
void EmptyCache(void);
bool IsGPUManagedMemorySupported(int dev_id);
bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id);
//! Get the primitive pointer return from cudaMalloc, just implemented with
//! testing, do not use for release
void *GetGpuBasePtr(void *ptr, int dev_id);
......
......@@ -290,6 +290,43 @@ void GpuDeviceSync() { PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); }
gpuError_t GpuGetLastError() { return cudaGetLastError(); }
// See
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#um-requirements
// for more detail about managed memory requirements
bool IsGPUManagedMemorySupported(int dev_id) {
PADDLE_ENFORCE_LT(dev_id,
GetGPUDeviceCount(),
paddle::platform::errors::InvalidArgument(
"Device id must be less than GPU count, "
"but received id is: %d. GPU count is: %d.",
dev_id,
GetGPUDeviceCount()));
#if defined(__linux__) || defined(_WIN32)
int ManagedMemoryAttr;
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute(
&ManagedMemoryAttr, cudaDevAttrManagedMemory, dev_id));
return ManagedMemoryAttr != 0;
#else
return false;
#endif
}
bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) {
PADDLE_ENFORCE_LT(dev_id,
GetGPUDeviceCount(),
paddle::platform::errors::InvalidArgument(
"Device id must be less than GPU count, "
"but received id is: %d. GPU count is: %d.",
dev_id,
GetGPUDeviceCount()));
#ifdef __linux__
return IsGPUManagedMemorySupported(dev_id) &&
GetGPUComputeCapability(dev_id) >= 60;
#else
return false;
#endif
}
} // namespace gpu
} // namespace backends
} // namespace pten
......@@ -104,6 +104,10 @@ void GpuDeviceSync();
gpuError_t GpuGetLastError();
bool IsGPUManagedMemorySupported(int dev_id);
bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id);
class GPUDeviceGuard {
public:
explicit inline GPUDeviceGuard(int dev_id) {
......
......@@ -292,6 +292,40 @@ void GpuDeviceSync() { PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); }
gpuError_t GpuGetLastError() { return hipGetLastError(); }
bool IsGPUManagedMemorySupported(int dev_id) {
PADDLE_ENFORCE_LT(dev_id,
GetGPUDeviceCount(),
paddle::platform::errors::InvalidArgument(
"Device id must be less than GPU count, "
"but received id is: %d. GPU count is: %d.",
dev_id,
GetGPUDeviceCount()));
#if defined(__linux__) || defined(_WIN32)
int ManagedMemoryAttr;
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute(
&ManagedMemoryAttr, hipDeviceAttributeManagedMemory, dev_id));
return ManagedMemoryAttr != 0;
#else
return false;
#endif
}
bool IsGPUManagedMemoryOversubscriptionSupported(int dev_id) {
PADDLE_ENFORCE_LT(dev_id,
GetGPUDeviceCount(),
paddle::platform::errors::InvalidArgument(
"Device id must be less than GPU count, "
"but received id is: %d. GPU count is: %d.",
dev_id,
GetGPUDeviceCount()));
#ifdef __linux__
return IsGPUManagedMemorySupported(dev_id) &&
GetGPUComputeCapability(dev_id) >= 60;
#else
return false;
#endif
}
} // namespace gpu
} // namespace backends
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册