From 9aaae254d5d4c46825d1627edff90c8e5bf9ee96 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 7 Jul 2022 12:38:14 +0800
Subject: [PATCH] Fix dev ctx with cuda graph (#44109)

---
 paddle/fluid/platform/CMakeLists.txt          |  4 ++
 .../platform/cuda_graph_with_memory_pool.cc   | 12 +++-
 .../device_context_test_cuda_graph.cu         | 39 +++++++++++++
 paddle/phi/core/device_context.cc             | 57 +++++++++++++++++++
 paddle/phi/core/device_context.h              | 27 +++++
 .../fluid/tests/unittests/test_cuda_graph.py  | 10 ++++
 6 files changed, 147 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/platform/device_context_test_cuda_graph.cu

diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index dc6911aecf1..efe04798712 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -297,6 +297,10 @@ if(WITH_GPU)
     device_context_test
     SRCS device_context_test.cu
     DEPS device_context gpu_info)
+  nv_test(
+    device_context_test_cuda_graph
+    SRCS device_context_test_cuda_graph.cu
+    DEPS device_context gpu_info cuda_graph_with_memory_pool)
   nv_test(
     transform_test
     SRCS transform_test.cu
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
index eb9f1ca845a..bfdf492962d 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
@@ -26,7 +26,9 @@ namespace platform {
 void BeginCUDAGraphCapture(platform::CUDAPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id) {
-  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* dev_ctx =
+      reinterpret_cast<platform::CUDADeviceContext*>(mutable_dev_ctx);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
 
   // After PR(#43206), cudnn related initializations will change to lazy mode.
@@ -49,6 +51,9 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,
   pool_id = CUDAGraph::SetMemoryPoolID(pool_id);
   memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph(
       pool_id);
+  dev_ctx->SetCUDAGraphAllocator(memory::allocation::AllocatorFacade::Instance()
+                                     .GetAllocator(place)
+                                     .get());
   if (old_value) {
     FLAGS_use_stream_safe_cuda_allocator = true;
   }
@@ -60,8 +65,11 @@
 
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
   auto place = CUDAGraph::CapturingPlace();
-  auto *dev_ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
+  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* dev_ctx =
+      reinterpret_cast<platform::CUDADeviceContext*>(mutable_dev_ctx);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
+  dev_ctx->SetCUDAGraphAllocator(nullptr);
   return CUDAGraph::EndCapture();
 }
 #endif
diff --git a/paddle/fluid/platform/device_context_test_cuda_graph.cu b/paddle/fluid/platform/device_context_test_cuda_graph.cu
new file mode 100644
index 00000000000..9f5a551743e
--- /dev/null
+++ b/paddle/fluid/platform/device_context_test_cuda_graph.cu
@@ -0,0 +1,39 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "cuda.h"          // NOLINT
+#include "cuda_runtime.h"  // NOLINT
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/memory/allocation/allocator_facade.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/fluid/platform/device_context.h"
+
+TEST(Device, DeviceContextWithCUDAGraph) {
+  using paddle::platform::CUDADeviceContext;
+  using paddle::platform::CUDAPlace;
+  using paddle::platform::DeviceContext;
+  using paddle::platform::DeviceContextPool;
+  using paddle::platform::Place;
+
+  DeviceContextPool& pool = DeviceContextPool::Instance();
+  Place place = CUDAPlace(0);
+  auto* dev_ctx = pool.Get(place);
+
+  paddle::platform::BeginCUDAGraphCapture(
+      place, cudaStreamCaptureMode::cudaStreamCaptureModeThreadLocal, 0);
+  ASSERT_EQ(dev_ctx->IsCUDAGraphAllocatorValid(), true);
+  dev_ctx->GetCUDAGraphAllocator();
+  paddle::platform::EndCUDAGraphCapture();
+  ASSERT_EQ(dev_ctx->IsCUDAGraphAllocatorValid(), false);
+}
diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc
index ce57f4f627b..fc85fc32f62 100644
--- a/paddle/phi/core/device_context.cc
+++ b/paddle/phi/core/device_context.cc
@@ -14,6 +14,10 @@
 
 #include "paddle/phi/core/device_context.h"
 
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
+#endif
+
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/selected_rows.h"
@@ -58,6 +62,26 @@ struct DeviceContext::Impl {
     pinned_allocator_ = allocator;
   }
 
+#ifdef PADDLE_WITH_CUDA
+  void SetCUDAGraphAllocator(const Allocator* allocator) {
+    // NOTE (Yuang): the cuda graph allocator may be set to nullptr, so do
+    // not check the validity of the allocator here
+    cuda_graph_allocator_ = allocator;
+  }
+
+  const Allocator& GetCUDAGraphAllocator() const {
+    PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
+                            phi::errors::InvalidArgument(
+                                "Required cuda_graph_allocator_ shall not be "
+                                "nullptr, but received nullptr."));
+    return *cuda_graph_allocator_;
+  }
+
+  bool IsCUDAGraphAllocatorValid() const {
+    return cuda_graph_allocator_ != nullptr;
+  }
+#endif
+
   const Allocator& GetAllocator() const {
     PADDLE_ENFORCE_NOT_NULL(
         device_allocator_,
@@ -111,6 +135,17 @@ struct DeviceContext::Impl {
     auto* allocator = tensor->numel() == 0
                          ? zero_allocator_
                          : (pinned ? pinned_allocator_ : device_allocator_);
+#ifdef PADDLE_WITH_CUDA
+    bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
+    if (must_cuda_graph_allocator && paddle::platform::is_gpu_place(place) &&
+        paddle::platform::CUDAGraph::IsThisThreadCapturing()) {
+      PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
+                              phi::errors::InvalidArgument(
+                                  "Required cuda_graph_allocator_ shall not be "
+                                  "nullptr, but received nullptr."));
+      allocator = cuda_graph_allocator_;
+    }
+#endif
     return tensor->AllocateFrom(
         const_cast<Allocator*>(allocator), dtype, requested_size);
   }
@@ -200,6 +235,9 @@ struct DeviceContext::Impl {
   const Allocator* host_allocator_{nullptr};
   const Allocator* zero_allocator_{nullptr};
   const Allocator* pinned_allocator_{nullptr};
+#ifdef PADDLE_WITH_CUDA
+  const Allocator* cuda_graph_allocator_{nullptr};
+#endif
   Generator* device_generator_{nullptr};
   Generator* host_generator_{nullptr};
 };
@@ -213,6 +251,11 @@ DeviceContext::DeviceContext(const DeviceContext& other) {
   impl_->SetPinnedAllocator(&other.GetPinnedAllocator());
   impl_->SetHostGenerator(other.GetHostGenerator());
   impl_->SetGenerator(other.GetGenerator());
+#ifdef PADDLE_WITH_CUDA
+  if (other.IsCUDAGraphAllocatorValid()) {
+    impl_->SetCUDAGraphAllocator(&other.GetCUDAGraphAllocator());
+  }
+#endif
 }
 
 DeviceContext::DeviceContext(DeviceContext&& other) {
@@ -239,6 +282,20 @@ const Allocator& DeviceContext::GetHostAllocator() const {
   return impl_->GetHostAllocator();
 }
 
+#ifdef PADDLE_WITH_CUDA
+void DeviceContext::SetCUDAGraphAllocator(const Allocator* allocator) {
+  impl_->SetCUDAGraphAllocator(allocator);
+}
+
+const Allocator& DeviceContext::GetCUDAGraphAllocator() const {
+  return impl_->GetCUDAGraphAllocator();
+}
+
+bool DeviceContext::IsCUDAGraphAllocatorValid() const {
+  return impl_->IsCUDAGraphAllocatorValid();
+}
+#endif
+
 void DeviceContext::SetZeroAllocator(const Allocator* allocator) {
   impl_->SetZeroAllocator(allocator);
 }
diff --git a/paddle/phi/core/device_context.h b/paddle/phi/core/device_context.h
index 45e4fbf64dc..32dbb0c0a35 100644
--- a/paddle/phi/core/device_context.h
+++ b/paddle/phi/core/device_context.h
@@ -106,6 +106,33 @@ class PADDLE_API DeviceContext {
 
   const Allocator& GetPinnedAllocator() const;
 
+#ifdef PADDLE_WITH_CUDA
+  /**
+   * @brief Set the CUDA graph Allocator object.
+   *
+   * @param allocator
+   */
+  void SetCUDAGraphAllocator(const Allocator*);
+
+  /**
+   * @brief Get the const CUDA graph Allocator object.
+   *
+   * @return Allocator
+   */
+  const Allocator& GetCUDAGraphAllocator() const;
+
+  /**
+   * @brief Test whether the CUDA graph allocator is valid.
+   *
+   * This method should be called before GetCUDAGraphAllocator();
+   * other units may call GetCUDAGraphAllocator() only when this
+   * method returns true.
+   *
+   * @return true if cuda_graph_allocator_ is valid, false otherwise
+   */
+  bool IsCUDAGraphAllocatorValid() const;
+#endif
+
   /**
    * @brief Allocate device memory for tensor.
   */
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
index fda3fa79ef6..446a5500bc3 100644
--- a/python/paddle/fluid/tests/unittests/test_cuda_graph.py
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py
@@ -236,6 +236,16 @@ class TestCUDAGraph(unittest.TestCase):
             self.assertTrue(np.array_equal(actual_x, x.numpy()))
             self.assertTrue(np.array_equal(actual_y, y.numpy()))
 
+    def test_dev_ctx_alloc(self):
+        if not can_use_cuda_graph():
+            return
+
+        x = paddle.to_tensor([2], dtype='float32')
+        graph = CUDAGraph()
+        graph.capture_begin()
+        y = paddle.cast(x, dtype='float16')
+        graph.capture_end()
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
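
The behavioral core of the patch is small: BeginCUDAGraphCapture registers the allocator of the graph's memory pool on the device context, Impl::Alloc reroutes non-zero-size, non-pinned device allocations to that allocator while a capture is in progress (so the memory comes from the pool reserved for the graph), and EndCUDAGraphCapture resets the hook to nullptr. The self-contained C++ sketch below restates that routing rule outside of Paddle; every name in it (Ctx, SetGraphAllocator, GraphPoolAllocator, and so on) is a hypothetical stand-in rather than Paddle API, and plain malloc stands in for device memory.

// Minimal standalone sketch (not Paddle code). Only the routing rule comes
// from the patch: while a CUDA graph capture is active, non-zero-size,
// non-pinned allocations must be served by the graph pool's allocator, and
// the hook is reset to nullptr when capture ends.
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <stdexcept>

struct Allocator {
  virtual void* Allocate(std::size_t n) = 0;
  virtual ~Allocator() = default;
};

// Stand-in for the default device allocator.
struct DefaultAllocator : Allocator {
  void* Allocate(std::size_t n) override { return std::malloc(n); }
};

// Stand-in for the allocator backed by the CUDA graph memory pool.
struct GraphPoolAllocator : Allocator {
  void* Allocate(std::size_t n) override { return std::malloc(n); }
};

class Ctx {
 public:
  // Mirrors DeviceContext::SetCUDAGraphAllocator: nullptr is allowed and
  // simply disables the capture-time path.
  void SetGraphAllocator(Allocator* a) { graph_allocator_ = a; }

  // Mirrors DeviceContext::IsCUDAGraphAllocatorValid.
  bool GraphAllocatorValid() const { return graph_allocator_ != nullptr; }

  // Mirrors Impl::Alloc: route qualifying allocations to the graph pool
  // while a capture is running; allocating during capture without a
  // registered graph allocator is an error.
  void* Alloc(std::size_t n, bool capturing, bool pinned) {
    Allocator* chosen = &default_allocator_;
    if (n != 0 && !pinned && capturing) {
      if (!GraphAllocatorValid()) {
        throw std::runtime_error("graph allocator not set during capture");
      }
      chosen = graph_allocator_;
    }
    return chosen->Allocate(n);
  }

 private:
  DefaultAllocator default_allocator_;
  Allocator* graph_allocator_ = nullptr;  // set only between begin/end capture
};

int main() {
  Ctx ctx;
  GraphPoolAllocator pool;

  ctx.SetGraphAllocator(&pool);  // what BeginCUDAGraphCapture now does
  assert(ctx.GraphAllocatorValid());
  void* p = ctx.Alloc(16, /*capturing=*/true, /*pinned=*/false);
  assert(p != nullptr);

  ctx.SetGraphAllocator(nullptr);  // what EndCUDAGraphCapture now does
  assert(!ctx.GraphAllocatorValid());
  std::free(p);
  return 0;
}

As in the patch, clearing the allocator is legal and merely disables the capture-time path, while allocating during capture with no allocator registered fails loudly; this mirrors the PADDLE_ENFORCE_NOT_NULL check added to Impl::Alloc and the ASSERT_EQ pair in device_context_test_cuda_graph.cu.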