From 69875dc42caa26a7e32c183f87f26fa986abd8b0 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Mon, 1 Feb 2021 16:03:46 +0800 Subject: [PATCH] [ROCM] update fluid memory for rocm35 (part1), test=develop (#30758) --- paddle/fluid/memory/CMakeLists.txt | 6 ++++ .../memory/allocation/allocator_facade.cc | 14 ++++----- .../allocator_facade_abs_flags_test.cc | 6 ++-- .../allocator_facade_frac_flags_test.cc | 6 ++-- ...o_growth_best_fit_allocator_facade_test.cc | 8 ++--- .../fluid/memory/allocation/cuda_allocator.cc | 10 ++++++- .../cuda_device_context_allocator.h | 24 +++++++++++---- .../allocation/naive_best_fit_allocator.cc | 30 +++++++++++-------- .../naive_best_fit_allocator_test.cc | 2 +- .../memory/allocation/pinned_allocator.cc | 8 +++++ .../memory/allocation/retry_allocator_test.cc | 4 +-- paddle/fluid/memory/detail/CMakeLists.txt | 8 +++-- paddle/fluid/memory/detail/buddy_allocator.cc | 4 +-- .../memory/detail/buddy_allocator_test.cc | 22 ++++++++++++-- .../fluid/memory/detail/system_allocator.cc | 26 ++++++++++++---- paddle/fluid/memory/detail/system_allocator.h | 2 +- .../memory/detail/system_allocator_test.cc | 6 +++- paddle/fluid/memory/memcpy.cc | 10 +++---- paddle/fluid/memory/memcpy.h | 4 +-- 19 files changed, 138 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index 13626ae777..75b1bffca3 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -19,6 +19,12 @@ if (WITH_GPU) DEPS device_context malloc) endif() +if (WITH_ROCM) + hip_test(malloc_test + SRCS malloc_test.cu + DEPS device_context malloc) +endif() + #if (WITH_GPU) # nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory) #endif() diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index a124a56ef8..b901a3668d 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -31,7 +31,7 @@ #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_allocator.h" #include "paddle/fluid/memory/allocation/pinned_allocator.h" #include "paddle/fluid/memory/allocation/thread_local_allocator.h" @@ -70,7 +70,7 @@ class AllocatorFacadePrivate { InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id)); } #endif -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { InitNaiveBestFitCUDAAllocator(platform::CUDAPlace(dev_id)); @@ -87,7 +87,7 @@ class AllocatorFacadePrivate { InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id)); } #endif -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id)); @@ -104,7 +104,7 @@ class AllocatorFacadePrivate { InitNaiveBestFitXPUAllocator(platform::XPUPlace(dev_id)); } #endif -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { InitThreadLocalCUDAAllocator(platform::CUDAPlace(dev_id)); @@ -152,7 +152,7 @@ class AllocatorFacadePrivate { system_allocators_[p] = std::make_shared(p); } #endif -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) system_allocators_[platform::CUDAPinnedPlace()] = std::make_shared(); int device_count = platform::GetCUDADeviceCount(); @@ -168,7 +168,7 @@ class AllocatorFacadePrivate { std::make_shared(platform::CPUPlace()); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void InitNaiveBestFitCUDAPinnedAllocator() { allocators_[platform::CUDAPinnedPlace()] = std::make_shared(platform::CUDAPinnedPlace()); @@ -215,7 +215,7 @@ class AllocatorFacadePrivate { void InitZeroSizeAllocators() { std::vector places; places.emplace_back(platform::CPUPlace()); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) int device_count = platform::GetCUDADeviceCount(); for (int dev_id = 0; dev_id < device_count; ++dev_id) { places.emplace_back(platform::CUDAPlace(dev_id)); diff --git a/paddle/fluid/memory/allocation/allocator_facade_abs_flags_test.cc b/paddle/fluid/memory/allocation/allocator_facade_abs_flags_test.cc index 0029991116..d3f16ec628 100644 --- a/paddle/fluid/memory/allocation/allocator_facade_abs_flags_test.cc +++ b/paddle/fluid/memory/allocation/allocator_facade_abs_flags_test.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/memory/allocation/allocator_facade.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_cuda_pinned_memory_to_use); DECLARE_uint64(initial_gpu_memory_in_mb); @@ -45,7 +45,7 @@ void AllocateTestCases() { ASSERT_EQ(cpu_allocation->size(), size); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) { place = platform::CUDAPlace(0); size = 1024; @@ -81,7 +81,7 @@ void AllocateTestCases() { } TEST(Allocator, SpecifyGpuMemory) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // Set to 0.0 to test FLAGS_initial_gpu_memory_in_mb and // FLAGS_reallocate_gpu_memory_in_mb FLAGS_fraction_of_gpu_memory_to_use = 0.0; diff --git a/paddle/fluid/memory/allocation/allocator_facade_frac_flags_test.cc b/paddle/fluid/memory/allocation/allocator_facade_frac_flags_test.cc index 1e793d1617..85cd851a21 100644 --- a/paddle/fluid/memory/allocation/allocator_facade_frac_flags_test.cc +++ b/paddle/fluid/memory/allocation/allocator_facade_frac_flags_test.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/memory/allocation/allocator_facade.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_cuda_pinned_memory_to_use); DECLARE_uint64(initial_gpu_memory_in_mb); @@ -45,7 +45,7 @@ void AllocateTestCases() { ASSERT_EQ(cpu_allocation->size(), size); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) { place = platform::CUDAPlace(0); size = 1024; @@ -81,7 +81,7 @@ void AllocateTestCases() { } TEST(Allocator, Allocator) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) FLAGS_fraction_of_gpu_memory_to_use = 0.01; FLAGS_gpu_allocator_retry_time = 500; FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_facade_test.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_facade_test.cc index 1dcc820b26..11e599c4b5 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_facade_test.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_facade_test.cc @@ -22,7 +22,7 @@ #include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/platform/gpu_info.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) DECLARE_double(fraction_of_gpu_memory_to_use); DECLARE_double(fraction_of_cuda_pinned_memory_to_use); DECLARE_int64(gpu_allocator_retry_time); @@ -40,7 +40,7 @@ static inline size_t AlignTo(size_t size, size_t alignment) { } TEST(allocator, allocator) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) FLAGS_fraction_of_gpu_memory_to_use = 0.01; FLAGS_gpu_allocator_retry_time = 500; FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5; @@ -62,7 +62,7 @@ TEST(allocator, allocator) { ASSERT_EQ(cpu_allocation->size(), AlignedSize(size, 1024)); } -#ifdef PADDLE_WITH_CUDA +#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_HIP) { place = platform::CUDAPlace(0); size = 1024; @@ -101,7 +101,7 @@ TEST(allocator, allocator) { TEST(multithread_allocate, test_segfault) { FLAGS_allocator_strategy = "auto_growth"; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::mutex mtx; std::condition_variable cv; bool flag = false; diff --git a/paddle/fluid/memory/allocation/cuda_allocator.cc b/paddle/fluid/memory/allocation/cuda_allocator.cc index 39d33cf20b..c1b12f5c0e 100644 --- a/paddle/fluid/memory/allocation/cuda_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_allocator.cc @@ -13,8 +13,16 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/cuda_allocator.h" + +#ifdef PADDLE_WITH_CUDA #include #include +#endif + +#ifdef PADDLE_WITH_HIP +#include +#endif + #include #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/enforce.h" @@ -39,7 +47,7 @@ Allocation* CUDAAllocator::AllocateImpl(size_t size) { void* ptr; auto result = platform::RecordedCudaMalloc(&ptr, size, place_.device); - if (LIKELY(result == cudaSuccess)) { + if (LIKELY(result == gpuSuccess)) { return new Allocation(ptr, size, platform::Place(place_)); } diff --git a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h index a8e458a999..3d6f1d7bcb 100644 --- a/paddle/fluid/memory/allocation/cuda_device_context_allocator.h +++ b/paddle/fluid/memory/allocation/cuda_device_context_allocator.h @@ -14,8 +14,6 @@ #pragma once -#include - #include #include #include @@ -79,17 +77,26 @@ class CUDADeviceContextAllocation : public Allocation { class CUDADeviceContextAllocator : public Allocator { public: explicit CUDADeviceContextAllocator(platform::CUDAPlace place, - cudaStream_t default_stream) + gpuStream_t default_stream) : place_(place), default_stream_(default_stream) { platform::CUDADeviceGuard guard(place_.device); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS( + hipEventCreateWithFlags(&event_, hipEventDisableTiming)); +#else PADDLE_ENFORCE_CUDA_SUCCESS( cudaEventCreate(&event_, cudaEventDisableTiming)); +#endif } ~CUDADeviceContextAllocator() { if (event_) { platform::CUDADeviceGuard guard(place_.device); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(event_)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(event_)); +#endif } } @@ -102,10 +109,15 @@ class CUDADeviceContextAllocator : public Allocator { platform::CUDADeviceGuard guard(place_.device); auto allocation = new CUDADeviceContextAllocation(memory::Alloc(place_, size)); - // Wait for the event on stream +// Wait for the event on stream +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event_, default_stream_)); + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(default_stream_, event_, 0)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event_, default_stream_)); PADDLE_ENFORCE_CUDA_SUCCESS( cudaStreamWaitEvent(default_stream_, event_, 0)); +#endif return allocation; } @@ -113,8 +125,8 @@ class CUDADeviceContextAllocator : public Allocator { private: platform::CUDAPlace place_; - cudaEvent_t event_{nullptr}; - cudaStream_t default_stream_{nullptr}; + gpuEvent_t event_{nullptr}; + gpuStream_t default_stream_{nullptr}; }; /** diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index fcde4cbab4..9ae63e74f4 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -26,7 +26,7 @@ #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/split.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif #ifdef PADDLE_WITH_XPU @@ -216,7 +216,7 @@ size_t Used(const platform::XPUPlace &place) { #endif } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class GPUBuddyAllocatorList { private: GPUBuddyAllocatorList() : devices_(platform::GetSelectedDevices()) { @@ -283,7 +283,7 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) { template <> size_t Used(const platform::CUDAPlace &place) { -#ifdef PADDLE_WITH_CUDA +#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_HIP) return GetGPUBuddyAllocator(place.device)->Used(); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -294,7 +294,7 @@ size_t Used(const platform::CUDAPlace &place) { template <> void *Alloc(const platform::CUDAPlace &place, size_t size) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto *buddy_allocator = GetGPUBuddyAllocator(place.device); auto *ptr = buddy_allocator->Alloc(size); if (ptr == nullptr) { @@ -311,7 +311,11 @@ void *Alloc(const platform::CUDAPlace &place, string::HumanReadableSize(Used(place)))); } else { if (FLAGS_init_allocated_mem) { +#ifdef PADDLE_WITH_HIP + hipMemset(ptr, 0xEF, size); +#else cudaMemset(ptr, 0xEF, size); +#endif } } return ptr; @@ -324,7 +328,7 @@ void *Alloc(const platform::CUDAPlace &place, template <> void Free(const platform::CUDAPlace &place, void *p, size_t size) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) GetGPUBuddyAllocator(place.device)->Free(p); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -334,7 +338,7 @@ void Free(const platform::CUDAPlace &place, void *p, template <> uint64_t Release(const platform::CUDAPlace &place) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return GetGPUBuddyAllocator(place.device)->Release(); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -342,7 +346,7 @@ uint64_t Release(const platform::CUDAPlace &place) { #endif } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) BuddyAllocator *GetCUDAPinnedBuddyAllocator() { static std::once_flag init_flag; static BuddyAllocator *ba = nullptr; @@ -360,7 +364,7 @@ BuddyAllocator *GetCUDAPinnedBuddyAllocator() { template <> size_t Used(const platform::CUDAPinnedPlace &place) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return GetCUDAPinnedBuddyAllocator()->Used(); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -371,7 +375,7 @@ size_t Used(const platform::CUDAPinnedPlace &place) { template <> void *Alloc(const platform::CUDAPinnedPlace &place, size_t size) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto *buddy_allocator = GetCUDAPinnedBuddyAllocator(); void *ptr = buddy_allocator->Alloc(size); @@ -392,7 +396,7 @@ void *Alloc(const platform::CUDAPinnedPlace &place, template <> void Free(const platform::CUDAPinnedPlace &place, void *p, size_t size) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) GetCUDAPinnedBuddyAllocator()->Free(p); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -403,7 +407,7 @@ void Free(const platform::CUDAPinnedPlace &place, template <> uint64_t Release( const platform::CUDAPinnedPlace &place) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return GetCUDAPinnedBuddyAllocator()->Release(); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -449,7 +453,7 @@ size_t Usage::operator()(const platform::CPUPlace &cpu) const { } size_t Usage::operator()(const platform::CUDAPlace &gpu) const { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return Used(gpu); #else PADDLE_THROW(platform::errors::PermissionDenied( @@ -458,7 +462,7 @@ size_t Usage::operator()(const platform::CUDAPlace &gpu) const { } size_t Usage::operator()(const platform::CUDAPinnedPlace &cuda_pinned) const { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return Used(cuda_pinned); #else PADDLE_THROW(platform::errors::PermissionDenied( diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc index 054c75b11f..b434b416fc 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator_test.cc @@ -41,7 +41,7 @@ TEST(NaiveBestFitAllocatorTest, CpuAlloc) { alloc.Release(platform::CPUPlace()); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(NaiveBestFitAllocatorTest, GpuAlloc) { NaiveBestFitAllocator alloc{platform::CUDAPlace(0)}; { diff --git a/paddle/fluid/memory/allocation/pinned_allocator.cc b/paddle/fluid/memory/allocation/pinned_allocator.cc index 42dd50af72..5aa0514432 100644 --- a/paddle/fluid/memory/allocation/pinned_allocator.cc +++ b/paddle/fluid/memory/allocation/pinned_allocator.cc @@ -19,12 +19,20 @@ namespace memory { namespace allocation { bool CPUPinnedAllocator::IsAllocThreadSafe() const { return true; } void CPUPinnedAllocator::FreeImpl(Allocation *allocation) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipHostFree(allocation->ptr())); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaFreeHost(allocation->ptr())); +#endif delete allocation; } Allocation *CPUPinnedAllocator::AllocateImpl(size_t size) { void *ptr; +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipHostMalloc(&ptr, size, hipHostMallocPortable)); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaHostAlloc(&ptr, size, cudaHostAllocPortable)); +#endif return new Allocation(ptr, size, platform::CUDAPinnedPlace()); } } // namespace allocation diff --git a/paddle/fluid/memory/allocation/retry_allocator_test.cc b/paddle/fluid/memory/allocation/retry_allocator_test.cc index 13b77c660c..7f95f9bcd5 100644 --- a/paddle/fluid/memory/allocation/retry_allocator_test.cc +++ b/paddle/fluid/memory/allocation/retry_allocator_test.cc @@ -26,7 +26,7 @@ #include "paddle/fluid/memory/allocation/best_fit_allocator.h" #include "paddle/fluid/memory/allocation/cpu_allocator.h" #include "paddle/fluid/memory/allocation/locked_allocator.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/memory/allocation/cuda_allocator.h" #endif @@ -127,7 +127,7 @@ TEST(RetryAllocator, RetryAllocatorLastAllocFailure) { } } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) { platform::CUDAPlace p(0); RetryAllocator allocator(std::make_shared(p), retry_ms); diff --git a/paddle/fluid/memory/detail/CMakeLists.txt b/paddle/fluid/memory/detail/CMakeLists.txt index 8f0988e871..fcae741db3 100644 --- a/paddle/fluid/memory/detail/CMakeLists.txt +++ b/paddle/fluid/memory/detail/CMakeLists.txt @@ -2,11 +2,13 @@ include(ExternalProject) cc_library(memory_block SRCS memory_block.cc memory_block_desc.cc meta_cache.cc DEPS place) -if(${WITH_GPU}) +if(WITH_GPU) nv_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info place) -else(${WITH_GPU}) +elseif(WITH_ROCM) + hip_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info place) +else() cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info place) -endif(${WITH_GPU}) +endif() cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator) diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index 3779571536..726b80c7db 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "gflags/gflags.h" #include "glog/logging.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) DECLARE_uint64(reallocate_gpu_memory_in_mb); #endif @@ -220,7 +220,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool( size_t allocate_bytes = max_chunk_size_; size_t index = 0; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (system_allocator_->UseGpu()) { if ((total_used_ + total_free_) == 0) { // Compute the allocation size for gpu for the first allocation. diff --git a/paddle/fluid/memory/detail/buddy_allocator_test.cc b/paddle/fluid/memory/detail/buddy_allocator_test.cc index 90f7e33eb3..2dc3e73af2 100644 --- a/paddle/fluid/memory/detail/buddy_allocator_test.cc +++ b/paddle/fluid/memory/detail/buddy_allocator_test.cc @@ -23,7 +23,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/platform/gpu_info.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include #include @@ -76,7 +76,7 @@ int* TestBuddyAllocator(BuddyAllocator* allocator, size_t size_bytes, return nullptr; } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(BuddyAllocator, GpuFraction) { // In a 16 GB machine, the pool size will be about 160 MB FLAGS_fraction_of_gpu_memory_to_use = 0.01; @@ -195,8 +195,13 @@ TEST(BuddyAllocator, AllocFromAvailable) { // Take half of available GPU void* p; +#ifdef PADDLE_WITH_HIP + hipError_t result = hipMalloc(&p, available >> 1); + EXPECT_TRUE(result == hipSuccess); +#else cudaError_t result = cudaMalloc(&p, available >> 1); EXPECT_TRUE(result == cudaSuccess); +#endif // BuddyAllocator should be able to alloc the remaining GPU BuddyAllocator buddy_allocator( @@ -209,7 +214,11 @@ TEST(BuddyAllocator, AllocFromAvailable) { TestBuddyAllocator(&buddy_allocator, static_cast(1 << 30)); if (p) { +#ifdef PADDLE_WITH_HIP + EXPECT_TRUE(hipFree(p) == hipSuccess); +#else EXPECT_TRUE(cudaFree(p) == cudaSuccess); +#endif } } @@ -219,7 +228,12 @@ TEST(BuddyAllocator, AllocFromAvailableWhenFractionIsOne) { FLAGS_reallocate_gpu_memory_in_mb = 0; void* p = nullptr; + +#ifdef PADDLE_WITH_HIP + EXPECT_TRUE(hipMalloc(&p, static_cast(1) << 30) == hipSuccess); +#else EXPECT_TRUE(cudaMalloc(&p, static_cast(1) << 30) == cudaSuccess); +#endif // BuddyAllocator should be able to alloc the remaining GPU BuddyAllocator buddy_allocator( @@ -230,7 +244,11 @@ TEST(BuddyAllocator, AllocFromAvailableWhenFractionIsOne) { TestBuddyAllocator(&buddy_allocator, static_cast(1) << 30); if (p) { +#ifdef PADDLE_WITH_HIP + EXPECT_TRUE(hipFree(p) == hipSuccess); +#else EXPECT_TRUE(cudaFree(p) == cudaSuccess); +#endif } } diff --git a/paddle/fluid/memory/detail/system_allocator.cc b/paddle/fluid/memory/detail/system_allocator.cc index 0fbbf405f0..4301ed4db1 100644 --- a/paddle/fluid/memory/detail/system_allocator.cc +++ b/paddle/fluid/memory/detail/system_allocator.cc @@ -35,7 +35,7 @@ limitations under the License. */ #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gpu_info.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif @@ -111,7 +111,7 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) { bool CPUAllocator::UseGpu() const { return false; } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void* GPUAllocator::Alloc(size_t* index, size_t size) { // CUDA documentation doesn't explain if cudaMalloc returns nullptr @@ -121,7 +121,7 @@ void* GPUAllocator::Alloc(size_t* index, size_t size) { void* p; auto result = platform::RecordedCudaMalloc(&p, size, gpu_id_); - if (result == cudaSuccess) { + if (result == gpuSuccess) { *index = 0; gpu_alloc_size_ += size; return p; @@ -193,10 +193,14 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) { } void* p; - // PINNED memory is visible to all CUDA contexts. +// PINNED memory is visible to all CUDA contexts. +#ifdef PADDLE_WITH_HIP + hipError_t result = hipHostMalloc(&p, size); +#else cudaError_t result = cudaHostAlloc(&p, size, cudaHostAllocPortable); +#endif - if (result == cudaSuccess) { + if (result == gpuSuccess) { *index = 1; // PINNED memory cuda_pinnd_alloc_size_ += size; return p; @@ -209,7 +213,7 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) { } void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { - cudaError_t err; + gpuError_t err; PADDLE_ENFORCE_EQ(index, 1, platform::errors::InvalidArgument( "The index should be 1, but got %d", index)); @@ -219,6 +223,15 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { "allocated cuda pinned memory (%d)", size, cuda_pinnd_alloc_size_)); cuda_pinnd_alloc_size_ -= size; +#ifdef PADDLE_WITH_HIP + err = hipHostFree(p); + if (err != hipErrorDeinitialized) { + PADDLE_ENFORCE_EQ( + err, hipSuccess, + platform::errors::Fatal( + "hipFreeHost failed in GPUPinnedAllocator, error code is %d", err)); + } +#else err = cudaFreeHost(p); // Purposefully allow cudaErrorCudartUnloading, because @@ -233,6 +246,7 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) { "cudaFreeHost failed in GPUPinnedAllocator, error code is %d", err)); } +#endif } bool CUDAPinnedAllocator::UseGpu() const { return false; } diff --git a/paddle/fluid/memory/detail/system_allocator.h b/paddle/fluid/memory/detail/system_allocator.h index 42f0f23ec1..e332bb670d 100644 --- a/paddle/fluid/memory/detail/system_allocator.h +++ b/paddle/fluid/memory/detail/system_allocator.h @@ -41,7 +41,7 @@ class CPUAllocator : public SystemAllocator { virtual bool UseGpu() const; }; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class GPUAllocator : public SystemAllocator { public: explicit GPUAllocator(int gpu_id) : gpu_id_(gpu_id) {} diff --git a/paddle/fluid/memory/detail/system_allocator_test.cc b/paddle/fluid/memory/detail/system_allocator_test.cc index ea4897494f..13854d771a 100644 --- a/paddle/fluid/memory/detail/system_allocator_test.cc +++ b/paddle/fluid/memory/detail/system_allocator_test.cc @@ -56,7 +56,7 @@ TEST(CPUAllocator, LockMem) { TestAllocator(&a, 0); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(GPUAllocator, Alloc) { paddle::memory::detail::GPUAllocator a(0); TestAllocator(&a, 2048); @@ -77,7 +77,11 @@ TEST(GPUAllocator, AllocFailure) { allocator.Alloc(&index, alloc_size); ASSERT_TRUE(false); } catch (paddle::memory::allocation::BadAlloc&) { +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError()); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); +#endif } } #endif diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index b17da7f69a..cf5885f049 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -222,7 +222,7 @@ inline void SyncCUDAStream() { template <> void Copy( platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place, - const void* src, size_t num, cudaStream_t stream) { + const void* src, size_t num, gpuStream_t stream) { if (UNLIKELY(num == 0)) return; platform::SetDeviceId(src_place.device); @@ -244,7 +244,7 @@ void Copy( template <> void Copy( platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place, - const void* src, size_t num, cudaStream_t stream) { + const void* src, size_t num, gpuStream_t stream) { if (UNLIKELY(num == 0)) return; platform::SetDeviceId(dst_place.device); @@ -266,7 +266,7 @@ void Copy( template <> void Copy( platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place, - const void* src, size_t num, cudaStream_t stream) { + const void* src, size_t num, gpuStream_t stream) { if (UNLIKELY(num == 0)) return; VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to " @@ -327,7 +327,7 @@ template <> void Copy( platform::CUDAPinnedPlace dst_place, void* dst, platform::CUDAPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + gpuStream_t stream) { if (UNLIKELY(num == 0)) return; platform::SetDeviceId(src_place.device); VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to " @@ -345,7 +345,7 @@ template <> void Copy( platform::CUDAPlace dst_place, void* dst, platform::CUDAPinnedPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + gpuStream_t stream) { if (UNLIKELY(num == 0)) return; platform::SetDeviceId(dst_place.device); diff --git a/paddle/fluid/memory/memcpy.h b/paddle/fluid/memory/memcpy.h index 7b2b8eb066..25490f28b6 100644 --- a/paddle/fluid/memory/memcpy.h +++ b/paddle/fluid/memory/memcpy.h @@ -33,7 +33,7 @@ namespace memory { template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) /** * \brief Copy memory from one place to another place. @@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); */ template void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, - cudaStream_t stream); + gpuStream_t stream); #endif } // namespace memory -- GitLab