From 62b1f38ccb1e1dbb7f255c35566c4bb5e0897f0d Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Fri, 10 Dec 2021 13:28:42 +0800
Subject: [PATCH] make cuda graph thread local allocator (#37814)

---
 .../memory/allocation/allocator_facade.cc     |  5 ++-
 .../platform/device/gpu/cuda/cuda_graph.cc    | 14 ++++++
 .../platform/device/gpu/cuda/cuda_graph.h     | 22 +++++++++
 .../fluid/tests/unittests/test_cuda_graph.py  | 45 ++++++++++++++++++-
 4 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc
index 2aed7ec001d..c836593f3f4 100644
--- a/paddle/fluid/memory/allocation/allocator_facade.cc
+++ b/paddle/fluid/memory/allocation/allocator_facade.cc
@@ -348,13 +348,14 @@ class AllocatorFacadePrivate {
 
   const AllocatorMap& GetAllocatorMap() {
 #ifdef PADDLE_WITH_CUDA
-    if (UNLIKELY(platform::CUDAGraph::IsCapturing())) {
+    if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
       auto id = platform::CUDAGraph::CapturingID();
       auto iter = cuda_graph_allocator_map_.find(id);
       PADDLE_ENFORCE_NE(
           iter, cuda_graph_allocator_map_.end(),
           platform::errors::PermissionDenied(
               "No memory pool is prepared for CUDA Graph capturing."));
+      VLOG(10) << "Choose CUDA Graph memory pool to allocate memory";
       return iter->second->allocators_;
     } else {
       return allocators_;
@@ -405,7 +406,7 @@ class AllocatorFacadePrivate {
 #if defined(PADDLE_WITH_HIP)
     auto cuda_allocator = std::make_shared<CUDAAllocator>(p);
     cuda_allocators_[p][stream] = std::make_shared<AutoGrowthBestFitAllocator>(
-        cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk_);
+        cuda_allocator, platform::GpuMinChunkSize(), 0, allow_free_idle_chunk_);
 #endif
 
 #if defined(PADDLE_WITH_CUDA)
diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc
index 3970acf82d3..8ee3b118c32 100644
--- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc
+++ b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc
@@ -18,6 +18,7 @@ namespace paddle {
 namespace platform {
 
 std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
+paddle::optional<std::thread::id> CUDAGraph::capturing_thread_id_{paddle::none};
 
 void CUDAGraph::Reset() {
   if (is_reset_) return;
@@ -58,6 +59,13 @@
       IsCapturing(), true,
       errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
                                "Graph is capturing."));
+  if (IsThreadLocalCapturing()) {
+    PADDLE_ENFORCE_EQ(IsThisThreadCapturing(), true,
+                      platform::errors::PermissionDenied(
+                          "When capturing CUDA Graph in the thread local mode, "
+                          "you cannot begin segmented capturing in the thread "
+                          "which is not the one that starts the capturing."));
+  }
   PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture(
       capturing_graph_->stream_, capturing_graph_->capture_mode_));
   PADDLE_ENFORCE_EQ(IsValidCapturing(), true,
@@ -82,6 +90,11 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream,
   capturing_graph_->place_ = place;
   capturing_graph_->stream_ = stream;
   capturing_graph_->capture_mode_ = mode;
+  if (mode == cudaStreamCaptureModeThreadLocal) {
+    capturing_thread_id_ = std::this_thread::get_id();
+    VLOG(10) << "Capturing CUDA Graph in thread local mode, thread id: "
+             << capturing_thread_id_;
+  }
   BeginSegmentCapture();
 #endif
 }
@@ -115,6 +128,7 @@ void CUDAGraph::EndSegmentCapture() {
 
 std::unique_ptr<CUDAGraph> CUDAGraph::EndCapture() {
   EndSegmentCapture();
+  capturing_thread_id_ = paddle::none;
   return std::move(capturing_graph_);
 }
 
diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
index 0856e0fad19..ca1e7abb375 100644
--- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
+++ b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
@@ -18,6 +18,7 @@
 #include <functional>
 #include <memory>
 #include <mutex>
+#include <thread>
 #include <vector>
 #include "cuda.h"          // NOLINT
 #include "cuda_runtime.h"  // NOLINT
@@ -26,6 +27,7 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/utils/optional.h"
 
 namespace paddle {
 namespace platform {
@@ -99,6 +101,25 @@ class CUDAGraph {
   // supported during capturing CUDA Graph.
   static bool IsValidCapturing();
 
+  static bool IsThreadLocalCapturing() {
+#if CUDA_VERSION >= 10010
+    return IsCapturing() &&
+           capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal;
+#else
+    return false;
+#endif
+  }
+
+  static bool IsThisThreadCapturing() {
+    if (UNLIKELY(IsCapturing())) {
+      return IsThreadLocalCapturing()
+                 ? capturing_thread_id_.get() == std::this_thread::get_id()
+                 : true;
+    } else {
+      return false;
+    }
+  }
+
  private:
   static CUDAGraphID UniqueID() {
     static std::atomic<CUDAGraphID> id;
@@ -118,6 +139,7 @@
   bool is_reset_{false};
   std::mutex mtx_;
 
+  static paddle::optional<std::thread::id> capturing_thread_id_;
   static std::unique_ptr<CUDAGraph> capturing_graph_;
 };
 
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
index 8b4eae8ada4..66228856eff 100644
--- a/python/paddle/fluid/tests/unittests/test_cuda_graph.py
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
@@ -34,7 +34,8 @@ class TestCUDAGraph(unittest.TestCase):
             paddle.set_flags({
                 'FLAGS_allocator_strategy': 'auto_growth',
                 'FLAGS_sync_nccl_allreduce': False,
-                'FLAGS_cudnn_deterministic': True
+                'FLAGS_cudnn_deterministic': True,
+                'FLAGS_use_stream_safe_cuda_allocator': False,
             })
 
     def random_tensor(self, shape):
@@ -187,6 +188,48 @@
         finally:
             graph.reset()
 
+    def test_dataloader(self):
+        if not can_use_cuda_graph():
+            return
+
+        class AutoIncDataset(paddle.io.Dataset):
+            def __init__(self, n, dtype):
+                self.n = n
+                self.dtype = dtype
+
+            def __len__(self):
+                return self.n
+
+            def __getitem__(self, idx):
+                return np.array([idx]).astype(self.dtype)
+
+        n = 100
+        dtype = 'int64'
+        dataset = AutoIncDataset(n, dtype)
+        data_loader = paddle.io.DataLoader(
+            dataset, batch_size=1, num_workers=2, use_buffer_reader=True)
+        x = None
+        y = None
+
+        graph = None
+        for i, data in enumerate(data_loader):
+            if graph is None:
+                x = data
+                x = x.cuda()
+                graph = CUDAGraph()
+                graph.capture_begin()
+                y = x * x
+                graph.capture_end()
+            else:
+                x.copy_(data, False)
+                x = x.cuda()
+
+            graph.replay()
+            actual_x = np.array([[i]]).astype(dtype)
+            actual_y = np.array([[i * i]]).astype(dtype)
+            self.assertTrue(np.array_equal(actual_x, x.numpy()))
+            self.assertTrue(np.array_equal(actual_y, y.numpy()))
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
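For context, the point of the thread-local capture mode wired up above is that
only the thread that started the capture is treated as capturing, so other
threads (for example DataLoader workers and the use_buffer_reader=True copy
thread) can keep issuing their own CUDA work while the main thread captures;
the new test_dataloader case exercises this pattern. A minimal sketch of the
core pattern, stripped of the DataLoader, assuming CUDAGraph is importable
from paddle.device.cuda.graphs (as in the test) and that its constructor
accepts a capture-mode argument; the mode="thread_local" keyword below is an
assumption mapping to cudaStreamCaptureModeThreadLocal in the C++ layer:

    import numpy as np
    import paddle
    from paddle.device.cuda.graphs import CUDAGraph

    # Static input buffer: graph replay reuses the same device memory,
    # so fresh data must be copied into this buffer before each replay.
    x = paddle.to_tensor(np.zeros([1, 1], dtype='int64')).cuda()

    graph = CUDAGraph(mode="thread_local")  # assumed kwarg, see note above
    graph.capture_begin()
    y = x * x  # allocations here come from the per-graph memory pool
    graph.capture_end()

    for i in range(10):
        # Refresh the captured input in place, then replay.
        x.copy_(paddle.to_tensor(np.array([[i]], dtype='int64')), False)
        graph.replay()  # y is recomputed in place: y == i * i
        assert int(y.numpy()[0][0]) == i * i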