未验证 提交 b9c464c3 编写于 作者: F From00 提交者: GitHub

Support multi-stream allocation for CUDA place (#37290)

* Support multi-stream allocation for CUDA place

* Do not notify the retrying from other streams when free CUDA allocation

* Fix compile error for CPU

* Fix compile error for HIP

* Release memory for StreamSafeCUDAAllocaRetry in malloc_test

* Add FLAGS_use_stream_safe_cuda_allocator

* Fix CI error for 'set_tests_properties'

* Invalidate stream safe CUDA allocator for naive_best_fit and thread_local strategy

* Performance improvement: insert allocation pair to outstanding_events_map when free but not alloc; replace recursive_mutex with SpinLock

* FLAGS priority changes: FLAGS_use_system_allocator > FLAGS_use_stream_safe_cuda_allocator

* Performance improvement: directly delete allocation when the recorded_streams is empty in FreeImpl of StreamSafeCUDAAllocator

* Add UT for alloc interface

* Changes multi-stream interface; move retry code from AllocatorFacadePrivate to StreamSafeCUDAAllocator
上级 adb54eb0
......@@ -17,6 +17,16 @@ if (WITH_GPU)
nv_test(malloc_test
SRCS malloc_test.cu
DEPS device_context malloc)
nv_test(stream_safe_cuda_alloc_test
SRCS stream_safe_cuda_alloc_test.cu
DEPS malloc)
if(WITH_TESTING AND TEST stream_safe_cuda_alloc_test)
set_tests_properties(stream_safe_cuda_alloc_test PROPERTIES
ENVIRONMENT "FLAGS_use_system_allocator=false"
ENVIRONMENT "FLAGS_enable_stream_safe_cuda_allocator=true"
ENVIRONMENT "FLAGS_allocator_strategy=auto_growth")
endif()
endif()
if (WITH_ROCM)
......
......@@ -15,8 +15,10 @@ endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
nv_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator)
nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
if(CUDA_VERSION GREATER_EQUAL 10.2)
nv_library(cuda_virtual_mem_allocator SRCS cuda_virtual_mem_allocator.cc DEPS dynload_cuda)
......@@ -25,8 +27,10 @@ endif()
if (WITH_ROCM)
hip_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
hip_library(stream_safe_cuda_allocator SRCS stream_safe_cuda_allocator.cc DEPS allocator)
hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
endif()
......@@ -38,7 +42,7 @@ endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
if (WITH_GPU OR WITH_ROCM)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator)
set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator stream_safe_cuda_allocator)
if(CUDA_VERSION GREATER_EQUAL 10.2)
list(APPEND AllocatorFacadeDeps cuda_virtual_mem_allocator)
endif()
......
......@@ -26,6 +26,7 @@
namespace paddle {
namespace memory {
namespace allocation {
#ifdef PADDLE_WITH_ASCEND_CL
using NPUPinnedAllocator = paddle::memory::allocation::NPUPinnedAllocator;
#endif
......@@ -40,26 +41,34 @@ using NPUPinnedAllocator = paddle::memory::allocation::NPUPinnedAllocator;
class AllocatorFacadePrivate;
class AllocatorFacade {
public:
~AllocatorFacade();
AllocatorFacade(const AllocatorFacade& o) = delete;
const AllocatorFacade& operator=(const AllocatorFacade& o) = delete;
~AllocatorFacade();
static AllocatorFacade& Instance();
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place);
// Allocate a shared allocation.
std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
size_t size);
// Allocate a unique allocation.
AllocationPtr Alloc(const platform::Place& place, size_t size);
// Release unused memory pool.
uint64_t Release(const platform::Place& place);
const std::shared_ptr<Allocator>& GetAllocator(const platform::Place& place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::shared_ptr<Allocation> AllocShared(const platform::CUDAPlace& place,
size_t size,
const gpuStream_t& stream);
AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size,
const gpuStream_t& stream);
uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream);
void RecordStream(Allocation* allocation, const gpuStream_t& stream);
#ifdef PADDLE_WITH_CUDA
void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id);
void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id);
#endif
#endif
// TODO(yy): Allocate a Copy-On-Write allocation?
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
StreamSafeCUDAAllocation::StreamSafeCUDAAllocation(
AllocationPtr underlying_allocation, gpuStream_t owning_stream)
: Allocation(underlying_allocation->ptr(), underlying_allocation->size(),
underlying_allocation->place()),
underlying_allocation_(std::move(underlying_allocation)),
owning_stream_(owning_stream),
recorded_streams_(std::make_shared<std::set<gpuStream_t>>()) {}
void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) {
VLOG(8) << "Record stream " << stream << " to " << ptr();
if (stream == owning_stream_) {
return;
}
std::lock_guard<SpinLock> lock_guard(spin_lock_);
recorded_streams_->insert(stream);
}
std::shared_ptr<std::set<gpuStream_t>>
StreamSafeCUDAAllocation::GetRecordedStreams() {
return recorded_streams_;
}
StreamSafeCUDAAllocator::StreamSafeCUDAAllocator(
const std::shared_ptr<Allocator>& underlying_allocator,
const platform::CUDAPlace& place, const gpuStream_t default_stream)
: underlying_allocator_(underlying_allocator),
place_(place),
default_stream_(default_stream) {
std::lock_guard<SpinLock> lock_guard(allocators_map_lock_);
allocators_map_[place].emplace_back(this);
}
StreamSafeCUDAAllocator::~StreamSafeCUDAAllocator() {
std::lock_guard<SpinLock> lock_guard(allocators_map_lock_);
std::vector<StreamSafeCUDAAllocator*>& allocators = allocators_map_[place_];
allocators.erase(std::remove(allocators.begin(), allocators.end(), this),
allocators.end());
}
bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; }
Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) {
ProcessEventsAndFree();
AllocationPtr underlying_allocation;
try {
underlying_allocation = underlying_allocator_->Allocate(size);
} catch (BadAlloc&) {
VLOG(9) << "Allocation failed when allocating " << size << " bytes";
uint64_t release_size = ReleaseImpl(place_);
VLOG(9) << "Release " << release_size << " bytes memory from all streams";
try {
underlying_allocation = underlying_allocator_->Allocate(size);
} catch (...) {
VLOG(9) << "Still allocation failed after release memory";
throw;
}
} catch (...) {
throw;
}
StreamSafeCUDAAllocation* allocation = new StreamSafeCUDAAllocation(
std::move(underlying_allocation), default_stream_);
return allocation;
}
void StreamSafeCUDAAllocator::FreeImpl(Allocation* allocation) {
if (dynamic_cast<StreamSafeCUDAAllocation*>(allocation)
->GetRecordedStreams()
->empty()) {
delete allocation;
} else {
std::lock_guard<SpinLock> lock_guard(outstanding_events_map_lock_);
FreeStreamSafeCUDAAllocation(allocation);
}
}
uint64_t StreamSafeCUDAAllocator::ReleaseImpl(const platform::Place& place) {
std::lock_guard<SpinLock> lock_guard(allocators_map_lock_);
std::vector<StreamSafeCUDAAllocator*>& allocators =
allocators_map_[BOOST_GET_CONST(platform::CUDAPlace, place)];
uint64_t release_size = 0;
for (StreamSafeCUDAAllocator* allocator : allocators) {
release_size += allocator->ProcessEventsAndFreeWithRelease();
}
return release_size;
}
void StreamSafeCUDAAllocator::CreateEventForAllRecordedStream(
std::set<gpuStream_t>* recorded_streams,
std::deque<gpuEvent_t>* outstanding_events) {
for (gpuStream_t stream : *recorded_streams) {
gpuEvent_t event;
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
hipEventCreateWithFlags(&event, hipEventDisableTiming));
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, stream));
#endif
outstanding_events->emplace_back(event);
VLOG(9) << "Record event " << event << " in stream " << stream;
}
recorded_streams->clear();
}
void StreamSafeCUDAAllocator::FreeStreamSafeCUDAAllocation(
Allocation* allocation) {
std::deque<gpuEvent_t>& outstanding_events =
outstanding_events_map_[allocation];
CreateEventForAllRecordedStream(
dynamic_cast<StreamSafeCUDAAllocation*>(allocation)
->GetRecordedStreams()
.get(),
&outstanding_events);
if (!outstanding_events.empty()) {
VLOG(8) << allocation->ptr() << " is not ready to free";
return;
}
VLOG(8) << "Free " << allocation->ptr();
outstanding_events_map_.erase(allocation);
delete allocation;
}
void StreamSafeCUDAAllocator::ProcessEventsAndFree() {
std::lock_guard<SpinLock> lock_guard(outstanding_events_map_lock_);
for (auto map_it = outstanding_events_map_.begin();
map_it != outstanding_events_map_.end();) {
std::deque<gpuEvent_t>& outstanding_events = map_it->second;
VLOG(10) << "Check " << outstanding_events.size()
<< " outstanding events for " << map_it->first->ptr();
auto deque_it = outstanding_events.begin();
while (deque_it != outstanding_events.end()) {
#ifdef PADDLE_WITH_CUDA
gpuError_t err = cudaEventQuery(*deque_it);
if (err == cudaErrorNotReady) {
VLOG(10) << "Event " << *deque_it << " for " << map_it->first->ptr()
<< " is not completed";
outstanding_events.erase(outstanding_events.begin(), deque_it);
break;
}
PADDLE_ENFORCE_CUDA_SUCCESS(err);
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(*deque_it));
#else
gpuError_t err = hipEventQuery(*deque_it);
if (err == hipErrorNotReady) {
VLOG(10) << "Event " << *deque_it << " for " << map_it->first->ptr()
<< " is not completed";
// Erase the completded event before "deque_it"
outstanding_events.erase(outstanding_events.begin(), deque_it);
break;
}
PADDLE_ENFORCE_CUDA_SUCCESS(err);
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventDestroy(*deque_it));
#endif
++deque_it;
}
if (deque_it == outstanding_events.end()) {
outstanding_events.clear();
Allocation* allocation = map_it->first;
// "map_it" may be invalid after calling FreeStreamSafeCUDAAllocation
auto next_it = ++map_it;
FreeStreamSafeCUDAAllocation(allocation);
map_it = next_it;
} else {
++map_it;
}
}
}
uint64_t StreamSafeCUDAAllocator::ProcessEventsAndFreeWithRelease() {
ProcessEventsAndFree();
return underlying_allocator_->Release(place_);
}
std::map<platform::CUDAPlace, std::vector<StreamSafeCUDAAllocator*>>
StreamSafeCUDAAllocator::allocators_map_;
SpinLock StreamSafeCUDAAllocator::allocators_map_lock_;
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#else
#include <hip/hip_runtime.h>
#endif
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/allocation/spin_lock.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
namespace allocation {
class StreamSafeCUDAAllocation : public Allocation {
public:
StreamSafeCUDAAllocation(AllocationPtr underlying_allocation,
gpuStream_t owning_stream);
void RecordStream(gpuStream_t stream);
std::shared_ptr<std::set<gpuStream_t>> GetRecordedStreams();
private:
AllocationPtr underlying_allocation_;
gpuStream_t owning_stream_;
std::shared_ptr<std::set<gpuStream_t>> recorded_streams_;
SpinLock spin_lock_;
};
class StreamSafeCUDAAllocator : public Allocator {
public:
StreamSafeCUDAAllocator(
const std::shared_ptr<Allocator> &underlying_allocator,
const platform::CUDAPlace &place, const gpuStream_t default_stream);
~StreamSafeCUDAAllocator();
bool IsAllocThreadSafe() const override;
protected:
Allocation *AllocateImpl(size_t size) override;
void FreeImpl(Allocation *allocation) override;
uint64_t ReleaseImpl(const platform::Place &place) override;
private:
void CreateEventForAllRecordedStream(
std::set<gpuStream_t> *recorded_streams,
std::deque<gpuEvent_t> *outstanding_events);
void FreeStreamSafeCUDAAllocation(Allocation *allocation);
void ProcessEventsAndFree();
uint64_t ProcessEventsAndFreeWithRelease();
static std::map<platform::CUDAPlace, std::vector<StreamSafeCUDAAllocator *>>
allocators_map_;
static SpinLock allocators_map_lock_;
std::shared_ptr<Allocator> underlying_allocator_;
platform::CUDAPlace place_;
gpuStream_t default_stream_;
std::map<Allocation *, std::deque<gpuEvent_t>> outstanding_events_map_;
SpinLock outstanding_events_map_lock_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -20,18 +20,40 @@ limitations under the License. */
namespace paddle {
namespace memory {
std::shared_ptr<Allocation> AllocShared(const platform::Place &place,
std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
size_t size) {
return allocation::AllocatorFacade::Instance().AllocShared(place, size);
}
AllocationPtr Alloc(const platform::Place &place, size_t size) {
AllocationPtr Alloc(const platform::Place& place, size_t size) {
return allocation::AllocatorFacade::Instance().Alloc(place, size);
}
uint64_t Release(const platform::Place &place) {
uint64_t Release(const platform::Place& place) {
return allocation::AllocatorFacade::Instance().Release(place);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::shared_ptr<Allocation> AllocShared(const platform::CUDAPlace& place,
size_t size,
const gpuStream_t& stream) {
return allocation::AllocatorFacade::Instance().AllocShared(place, size,
stream);
}
AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size,
const gpuStream_t& stream) {
return allocation::AllocatorFacade::Instance().Alloc(place, size, stream);
}
uint64_t Release(const platform::CUDAPlace& place, const gpuStream_t& stream) {
return allocation::AllocatorFacade::Instance().Release(place, stream);
}
void RecordStream(Allocation* allocation, const gpuStream_t& stream) {
return allocation::AllocatorFacade::Instance().RecordStream(allocation,
stream);
}
#endif
} // namespace memory
} // namespace paddle
......@@ -40,5 +40,18 @@ extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size);
extern uint64_t Release(const platform::Place& place);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
extern std::shared_ptr<Allocation> AllocShared(const platform::CUDAPlace& place,
size_t size,
const gpuStream_t& stream);
extern AllocationPtr Alloc(const platform::CUDAPlace& place, size_t size,
const gpuStream_t& stream);
extern uint64_t Release(const platform::CUDAPlace& place,
const gpuStream_t& stream);
void RecordStream(Allocation* allocation, const gpuStream_t& stream);
#endif
} // namespace memory
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#endif
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace memory {
__global__ void add_kernel(int *x, int n) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < n; i += blockDim.x * gridDim.x) {
atomicAdd(x + i, tid);
}
}
class StreamSafeCUDAAllocTest : public ::testing::Test {
protected:
void SetUp() override {
place_ = platform::CUDAPlace();
stream_num_ = 64;
grid_num_ = 1;
block_num_ = 64;
data_num_ = 64;
default_stream = nullptr;
streams_.reserve(stream_num_);
streams_.emplace_back(default_stream);
for (size_t i = 1; i < stream_num_; ++i) {
gpuStream_t stream;
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream));
#endif
streams_.emplace_back(stream);
}
for (size_t i = 0; i < stream_num_; ++i) {
size_t allocation_size = data_num_ * sizeof(int);
std::shared_ptr<Allocation> allocation =
AllocShared(place_, allocation_size, streams_[i]);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemset(allocation->ptr(), 0, allocation->size()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
hipMemset(allocation->ptr(), 0, allocation->size()));
#endif
allocations_.emplace_back(allocation);
}
}
void SingleStreamRun(size_t idx) {
for (size_t i = 0; i < stream_num_; ++i) {
int *x = reinterpret_cast<int *>(allocations_[i]->ptr());
add_kernel<<<grid_num_, block_num_, 0, streams_[idx]>>>(x, data_num_);
if (i != idx) {
RecordStream(allocations_[i].get(), streams_[idx]);
}
}
}
void MultiStreamRun() {
for (int i = 0; i < stream_num_; ++i) {
SingleStreamRun(i);
}
allocations_.clear(); // fast_gc
}
void MultiThreadMUltiStreamRun() {
std::vector<std::thread> threads;
for (size_t i = 0; i < stream_num_; ++i) {
threads.push_back(
std::thread(&StreamSafeCUDAAllocTest::SingleStreamRun, this, i));
}
for (size_t i = 0; i < stream_num_; ++i) {
threads[i].join();
}
allocations_.clear(); // fast_gc
}
void CheckResult() {
auto host_x = std::unique_ptr<int[]>(new int[data_num_]);
size_t thread_num = grid_num_ * block_num_;
for (int i = 0; i < stream_num_; ++i) {
// tricky code, the allocations are still accessible even though
// allocations_.clear() has been called
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(host_x.get(), allocations_[i]->ptr(),
data_num_ * sizeof(int), cudaMemcpyDeviceToHost));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
hipMemcpy(host_x.get(), allocations_[i]->ptr(),
data_num_ * sizeof(int), hipMemcpyDeviceToHost));
#endif
for (int j = 0; j < data_num_; ++j) {
EXPECT_TRUE(host_x[j] == (j % thread_num) * stream_num_);
}
}
}
void TearDown() override {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize());
#endif
for (gpuStream_t stream : streams_) {
Release(place_, stream);
}
for (size_t i = 1; i < stream_num_; ++i) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(streams_[i]));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamDestroy(streams_[i]));
#endif
}
uint64_t cuda_malloc_size =
platform::RecordedCudaMallocSize(place_.GetDeviceId());
ASSERT_EQ(cuda_malloc_size, 0) << "Found " << cuda_malloc_size
<< " bytes memory that not released yet,"
<< " there may be a memory leak problem";
}
size_t stream_num_;
size_t grid_num_;
size_t block_num_;
size_t data_num_;
platform::CUDAPlace place_;
gpuStream_t default_stream;
std::vector<gpuStream_t> streams_;
std::vector<std::shared_ptr<Allocation>> allocations_;
};
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilStreamTest) {
MultiStreamRun();
CheckResult();
}
TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) {
MultiThreadMUltiStreamRun();
CheckResult();
}
TEST(StreamSafeCUDAAllocInterfaceTest, AllocInterfaceTest) {
platform::CUDAPlace place = platform::CUDAPlace();
size_t alloc_size = 256;
std::shared_ptr<Allocation> allocation_implicit_stream =
AllocShared(place, alloc_size);
EXPECT_GE(allocation_implicit_stream->size(), alloc_size);
void *address = allocation_implicit_stream->ptr();
allocation_implicit_stream.reset();
gpuStream_t default_stream = nullptr;
allocation::AllocationPtr allocation_unique =
Alloc(place, alloc_size, default_stream);
EXPECT_GE(allocation_unique->size(), alloc_size);
EXPECT_EQ(allocation_unique->ptr(), address);
}
TEST(StreamSafeCUDAAllocRetryTest, RetryTest) {
platform::CUDAPlace place = platform::CUDAPlace();
gpuStream_t stream1, stream2;
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream1));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream2));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream1));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream2));
#endif
size_t available_size = platform::GpuAvailableMemToAlloc();
// alloc_size < available_size < 2 * alloc_size
size_t alloc_size = available_size / 4 * 3;
std::shared_ptr<Allocation> allocation1 =
AllocShared(place, alloc_size, stream1);
std::shared_ptr<Allocation> allocation2;
std::thread th([&allocation2, &place, &stream2, alloc_size]() {
std::this_thread::sleep_for(std::chrono::seconds(1));
allocation2 = AllocShared(place, alloc_size, stream2);
});
allocation1.reset(); // free but not release
th.join();
EXPECT_GE(allocation2->size(), alloc_size);
allocation2.reset();
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize());
#endif
Release(place, stream1);
Release(place, stream2);
}
} // namespace memory
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册