From 21b93c3dc68c616f12c360ebbbd9961fe379902f Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 29 Sep 2021 17:12:17 +0800 Subject: [PATCH] Add basic support for CUDA Graph (#36190) * add basic support for CUDA Graph * fix ci compile error * fix LOG print, fix windows CI * follow comments and update * small fix for default ctor * fix rocm compile error * fix CPU compile error --- paddle/fluid/memory/allocation/CMakeLists.txt | 6 +- .../memory/allocation/allocator_facade.cc | 147 ++++++++++++++++-- .../memory/allocation/allocator_facade.h | 8 + .../auto_growth_best_fit_allocator.cc | 8 +- .../auto_growth_best_fit_allocator.h | 3 +- paddle/fluid/platform/CMakeLists.txt | 5 + paddle/fluid/platform/cuda_graph.cc | 92 +++++++++++ paddle/fluid/platform/cuda_graph.h | 136 ++++++++++++++++ .../platform/cuda_graph_with_memory_pool.cc | 43 +++++ .../platform/cuda_graph_with_memory_pool.h | 64 ++++++++ paddle/fluid/platform/gpu_info.cc | 2 + paddle/fluid/platform/type_defs.h | 1 + paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/pybind.cc | 15 ++ python/paddle/device/cuda/graphs.py | 57 +++++++ .../fluid/tests/unittests/test_cuda_graph.py | 60 +++++++ 16 files changed, 634 insertions(+), 15 deletions(-) create mode 100644 paddle/fluid/platform/cuda_graph.cc create mode 100644 paddle/fluid/platform/cuda_graph.h create mode 100644 paddle/fluid/platform/cuda_graph_with_memory_pool.cc create mode 100644 paddle/fluid/platform/cuda_graph_with_memory_pool.h create mode 100644 python/paddle/device/cuda/graphs.py create mode 100644 python/paddle/fluid/tests/unittests/test_cuda_graph.py diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 6b4afae9f8..4aa1900f53 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -82,7 +82,11 @@ endif() cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator) cc_test(test_aligned_allocator SRCS test_aligned_allocator.cc DEPS aligned_allocator) cc_library(allocator_strategy SRCS allocator_strategy.cc DEPS gflags ${AllocatorFacadeDeps}) -cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy ) +cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy) + +if (WITH_GPU) + target_link_libraries(allocator_facade cuda_graph) +endif() cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator) if (WITH_TESTING) diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 78bce53b6f..0388e2d13a 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -32,6 +32,9 @@ #include "paddle/fluid/memory/allocation/thread_local_allocator.h" #include "paddle/fluid/platform/gpu_info.h" #endif +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_graph.h" +#endif #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu/xpu_info.h" #endif @@ -47,17 +50,64 @@ PADDLE_DEFINE_EXPORTED_bool( "Whether to use system allocator to allocate CPU and GPU memory. " "Only used for unittests."); +DECLARE_string(allocator_strategy); + namespace paddle { namespace memory { namespace allocation { +#ifdef PADDLE_WITH_CUDA +class CUDAGraphAllocator + : public Allocator, + public std::enable_shared_from_this { + private: + class PrivateAllocation : public Allocation { + public: + PrivateAllocation(CUDAGraphAllocator* allocator, + AllocationPtr underlying_allocation) + : Allocation(underlying_allocation->ptr(), + underlying_allocation->size(), + underlying_allocation->place()), + allocator_(allocator->shared_from_this()), + underlying_allocation_(std::move(underlying_allocation)) {} + + private: + std::shared_ptr allocator_; + AllocationPtr underlying_allocation_; + }; + + explicit CUDAGraphAllocator(const std::shared_ptr& allocator) + : underlying_allocator_(allocator) {} + + public: + static std::shared_ptr Create( + const std::shared_ptr& allocator) { + return std::shared_ptr(new CUDAGraphAllocator(allocator)); + } + + protected: + Allocation* AllocateImpl(size_t size) { + VLOG(10) << "Allocate " << size << " for CUDA Graph"; + return new PrivateAllocation(this, underlying_allocator_->Allocate(size)); + } + + void FreeImpl(Allocation* allocation) { + VLOG(10) << "delete for CUDA Graph"; + delete allocation; + } + + private: + std::shared_ptr underlying_allocator_; +}; +#endif + class AllocatorFacadePrivate { public: using AllocatorMap = std::map>; - AllocatorFacadePrivate() { - auto strategy = GetAllocatorStrategy(); - switch (strategy) { + explicit AllocatorFacadePrivate(bool allow_free_idle_chunk = true) { + strategy_ = GetAllocatorStrategy(); + switch (strategy_) { case AllocatorStrategy::kNaiveBestFit: { InitNaiveBestFitCPUAllocator(); #ifdef PADDLE_WITH_XPU @@ -91,7 +141,8 @@ class AllocatorFacadePrivate { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) { - InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id)); + InitAutoGrowthCUDAAllocator(platform::CUDAPlace(dev_id), + allow_free_idle_chunk); } InitNaiveBestFitCUDAPinnedAllocator(); #endif @@ -117,7 +168,7 @@ class AllocatorFacadePrivate { default: { PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported allocator strategy: %d", static_cast(strategy))); + "Unsupported allocator strategy: %d", static_cast(strategy_))); } } InitZeroSizeAllocators(); @@ -130,11 +181,29 @@ class AllocatorFacadePrivate { CheckAllocThreadSafe(); } + inline const AllocatorMap& GetAllocatorMap() { +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(platform::CUDAGraph::IsCapturing())) { + auto id = platform::CUDAGraph::CapturingID(); + auto iter = cuda_graph_allocator_map_.find(id); + PADDLE_ENFORCE_NE( + iter, cuda_graph_allocator_map_.end(), + platform::errors::PermissionDenied( + "No memory pool is prepared for CUDA Graph capturing.")); + return iter->second->allocators_; + } else { + return allocators_; + } +#else + return allocators_; +#endif + } + inline const std::shared_ptr& GetAllocator( const platform::Place& place, size_t size) { const auto& allocators = (size > 0 ? (UNLIKELY(FLAGS_use_system_allocator) ? system_allocators_ - : allocators_) + : GetAllocatorMap()) : zero_size_allocators_); auto iter = allocators.find(place); PADDLE_ENFORCE_NE(iter, allocators.end(), @@ -145,6 +214,7 @@ class AllocatorFacadePrivate { private: void InitSystemAllocators() { + if (!system_allocators_.empty()) return; system_allocators_[platform::CPUPlace()] = std::make_shared(); #ifdef PADDLE_WITH_XPU int device_count = platform::GetXPUDeviceCount(); @@ -183,10 +253,11 @@ class AllocatorFacadePrivate { allocators_[p] = std::make_shared(p); } - void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p) { + void InitAutoGrowthCUDAAllocator(platform::CUDAPlace p, + bool allow_free_idle_chunk) { auto cuda_allocator = std::make_shared(p); allocators_[p] = std::make_shared( - cuda_allocator, platform::GpuMinChunkSize()); + cuda_allocator, platform::GpuMinChunkSize(), allow_free_idle_chunk); } #endif @@ -226,6 +297,7 @@ class AllocatorFacadePrivate { }; void InitZeroSizeAllocators() { + if (!zero_size_allocators_.empty()) return; std::vector places; places.emplace_back(platform::CPUPlace()); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -279,12 +351,57 @@ class AllocatorFacadePrivate { } } +#ifdef PADDLE_WITH_CUDA + + public: + void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { + PADDLE_ENFORCE_EQ(strategy_, AllocatorStrategy::kAutoGrowth, + platform::errors::InvalidArgument( + "CUDA Graph is only supported when the " + "FLAGS_allocator_strategy=\"auto_growth\", but got " + "FLAGS_allocator_strategy=\"%s\"", + FLAGS_allocator_strategy)); + auto& allocator = cuda_graph_allocator_map_[id]; + PADDLE_ENFORCE_EQ( + allocator.get(), nullptr, + platform::errors::InvalidArgument( + "The memory pool of the CUDA Graph with ID %d have been prepared.", + id)); + allocator.reset( + new AllocatorFacadePrivate(/*allow_free_idle_chunk=*/false)); + for (auto& item : allocator->allocators_) { + auto& old_allocator = item.second; + old_allocator = CUDAGraphAllocator::Create(old_allocator); + } + VLOG(10) << "Prepare memory pool for CUDA Graph with ID " << id; + } + + void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { + auto iter = cuda_graph_allocator_map_.find(id); + PADDLE_ENFORCE_NE(iter, cuda_graph_allocator_map_.end(), + platform::errors::InvalidArgument( + "Cannot find CUDA Graph with ID = %d", id)); + cuda_graph_allocator_map_.erase(iter); + VLOG(10) << "Remove memory pool of CUDA Graph with ID " << id; + } +#endif + private: AllocatorMap allocators_; - AllocatorMap zero_size_allocators_; - AllocatorMap system_allocators_; +#ifdef PADDLE_WITH_CUDA + std::unordered_map> + cuda_graph_allocator_map_; +#endif + AllocatorStrategy strategy_; + + static AllocatorMap zero_size_allocators_; + static AllocatorMap system_allocators_; }; +AllocatorFacadePrivate::AllocatorMap + AllocatorFacadePrivate::zero_size_allocators_; +AllocatorFacadePrivate::AllocatorMap AllocatorFacadePrivate::system_allocators_; + // Pimpl. Make interface clean. AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {} // delete m_ may cause core dump when the destructor of python in conflict with @@ -316,6 +433,16 @@ const std::shared_ptr& AllocatorFacade::GetAllocator( return m_->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); } +#ifdef PADDLE_WITH_CUDA +void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(CUDAGraphID id) { + return m_->PrepareMemoryPoolForCUDAGraph(id); +} + +void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id) { + return m_->RemoveMemoryPoolOfCUDAGraph(id); +} +#endif + } // namespace allocation } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 7f6ad561aa..8d889ec38e 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -18,6 +18,9 @@ #ifdef PADDLE_WITH_ASCEND_CL #include "paddle/fluid/memory/allocation/npu_pinned_allocator.h" #endif +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/gpu_info.h" +#endif #include "paddle/fluid/platform/place.h" namespace paddle { @@ -54,6 +57,11 @@ class AllocatorFacade { uint64_t Release(const platform::Place& place); const std::shared_ptr& GetAllocator(const platform::Place& place); +#ifdef PADDLE_WITH_CUDA + void PrepareMemoryPoolForCUDAGraph(CUDAGraphID id); + void RemoveMemoryPoolOfCUDAGraph(CUDAGraphID id); +#endif + // TODO(yy): Allocate a Copy-On-Write allocation? private: AllocatorFacade(); diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc index a35d8a73f7..f36d589f90 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc @@ -39,11 +39,12 @@ namespace allocation { AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator( const std::shared_ptr &underlying_allocator, size_t alignment, - size_t chunk_size) + size_t chunk_size, bool allow_free_idle_chunk) : underlying_allocator_( std::make_shared(underlying_allocator, alignment)), alignment_(alignment), - chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)) {} + chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)), + allow_free_idle_chunk_(allow_free_idle_chunk) {} Allocation *AutoGrowthBestFitAllocator::AllocateImpl(size_t size) { size = AlignedSize(size, alignment_); @@ -139,6 +140,9 @@ void AutoGrowthBestFitAllocator::FreeImpl(Allocation *allocation) { } uint64_t AutoGrowthBestFitAllocator::FreeIdleChunks() { + if (!allow_free_idle_chunk_) { + return 0; + } uint64_t bytes = 0; for (auto chunk_it = chunks_.begin(); chunk_it != chunks_.end();) { auto &blocks = chunk_it->blocks_; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h index 5ed6eb94f1..d1fa6cce01 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h @@ -31,7 +31,7 @@ class AutoGrowthBestFitAllocator : public Allocator { public: AutoGrowthBestFitAllocator( const std::shared_ptr &underlying_allocator, size_t alignment, - size_t chunk_size = 0); + size_t chunk_size = 0, bool allow_free_idle_chunk = true); bool IsAllocThreadSafe() const override { return true; } @@ -86,6 +86,7 @@ class AutoGrowthBestFitAllocator : public Allocator { std::list chunks_; size_t alignment_; size_t chunk_size_; + bool allow_free_idle_chunk_; SpinLock spinlock_; }; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 2540170ed5..21213f9e6f 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -59,9 +59,14 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS}) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) IF(WITH_GPU) + nv_library(cuda_graph SRCS cuda_graph.cc DEPS enforce allocator_facade) nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda) nv_library(cuda_profiler SRCS cuda_profiler.cc DEPS enforce) + nv_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade cuda_graph) +ELSE() + cc_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade) ENDIF() + IF(WITH_ROCM) hip_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda) ENDIF() diff --git a/paddle/fluid/platform/cuda_graph.cc b/paddle/fluid/platform/cuda_graph.cc new file mode 100644 index 0000000000..6e518d779e --- /dev/null +++ b/paddle/fluid/platform/cuda_graph.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/cuda_graph.h" + +namespace paddle { +namespace platform { + +std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; + +void CUDAGraph::Reset() { + if (is_reset_) return; +#if CUDA_VERSION >= 10010 + if (graph_) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphDestroy(graph_)); + graph_ = nullptr; + } + if (exec_graph_) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphExecDestroy(exec_graph_)); + exec_graph_ = nullptr; + } +#endif + // callback should be called in reverse order because the latter added + // callback may rely on the former added callback. + for (auto iter = callbacks_.rbegin(); iter != callbacks_.rend(); ++iter) { + (*iter)(); + } + callbacks_.clear(); + is_reset_ = true; +} + +void CUDAGraph::Replay() { +#if CUDA_VERSION >= 10010 + PADDLE_ENFORCE_EQ(is_reset_, false, + errors::PermissionDenied( + "Cannot replay the CUDA Graph after reset is called.")); + PADDLE_ENFORCE_NOT_NULL(exec_graph_, + errors::PermissionDenied( + "CUDA Graph must be captured before replaying.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaGraphLaunch(exec_graph_, stream_)); +#endif +} + +void CUDAGraph::BeginCapture(platform::CUDAPlace place, cudaStream_t stream, + cudaStreamCaptureMode mode) { + ThrowErrorIfNotSupportCUDAGraph(); + PADDLE_ENFORCE_EQ( + IsCapturing(), false, + errors::PermissionDenied("CUDA Graph can only captured one by one.")); + PADDLE_ENFORCE_NOT_NULL( + stream, errors::PermissionDenied( + "CUDA Graph cannot be captured in default CUDA stream 0.")); + capturing_graph_.reset(new CUDAGraph()); + capturing_graph_->place_ = place; + capturing_graph_->stream_ = stream; + + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaStreamBeginCapture(capturing_graph_->stream_, mode)); + cudaStreamCaptureStatus status; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamGetCaptureInfo( + capturing_graph_->stream_, &status, &(capturing_graph_->id_))); + VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_; +} + +std::unique_ptr CUDAGraph::EndCapture() { + ThrowErrorIfNotSupportCUDAGraph(); +#if CUDA_VERSION >= 10010 + PADDLE_ENFORCE_EQ(IsCapturing(), true, + errors::PermissionDenied("No CUDA Graph is capturing.")); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamEndCapture( + capturing_graph_->stream_, &(capturing_graph_->graph_))); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaGraphInstantiate(&(capturing_graph_->exec_graph_), + capturing_graph_->graph_, nullptr, nullptr, 0)); + VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_; + return std::move(capturing_graph_); +#endif +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph.h b/paddle/fluid/platform/cuda_graph.h new file mode 100644 index 0000000000..41e36049aa --- /dev/null +++ b/paddle/fluid/platform/cuda_graph.h @@ -0,0 +1,136 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "cuda.h" // NOLINT +#include "cuda_runtime.h" // NOLINT +#include "paddle/fluid/platform/type_defs.h" + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace platform { + +#if CUDA_VERSION >= 10010 +static void ThrowErrorIfNotSupportCUDAGraph() {} +#else +enum cudaStreamCaptureMode { + cudaStreamCaptureModeGlobal = 0, + cudaStreamCaptureModeThreadLocal = 1, + cudaStreamCaptureModeRelaxed = 2 +}; +static void ThrowErrorIfNotSupportCUDAGraph() { + PADDLE_THROW(platform::errors::Unimplemented( + "CUDA Graph is only supported when CUDA version >= 10.1")); +} +#endif + +// NOTE: Currently, we do not support to capture CUDA graph in parallel +// NOTE: Do not use this class directly because it should be used with +// the memory pool. +class CUDAGraph { + DISABLE_COPY_AND_ASSIGN(CUDAGraph); + + // Since the constructor would throw error is CUDA_VERSION < 10010. + // The non-static method of CUDAGraph need not check CUDA_VERSION + // again. + CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); } + + public: + ~CUDAGraph() { Reset(); } + + CUDAGraphID ID() const { return id_; } + + void Replay(); + + void Reset(); + + void AddResetCallback(std::function callback) { + std::lock_guard guard(mtx_); + callbacks_.push_back(std::move(callback)); + } + + static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream, + cudaStreamCaptureMode mode); + static std::unique_ptr EndCapture(); + static void AddResetCallbackDuringCapturing(std::function callback) { + capturing_graph_->AddResetCallback(std::move(callback)); + } + + // No need to add CUDA_VERSION macro because capturing_graph_ would + // always be nullptr (constructor throws error) + static bool IsCapturing() { return capturing_graph_ != nullptr; } + + static CUDAGraphID CapturingID() { return capturing_graph_->id_; } + + static platform::CUDAPlace CapturingPlace() { + return capturing_graph_->place_; + } + + private: +#if CUDA_VERSION >= 10010 + cudaGraph_t graph_{nullptr}; + cudaGraphExec_t exec_graph_{nullptr}; +#endif + cudaStream_t stream_{nullptr}; + platform::CUDAPlace place_; + CUDAGraphID id_{0}; + std::vector> callbacks_; + bool is_reset_{false}; + std::mutex mtx_; + + static std::unique_ptr capturing_graph_; +}; + +#if CUDA_VERSION >= 10010 +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard(cudaStreamCaptureMode mode) { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode)); + // After cudaThreadExchangeStreamCaptureMode is called, + // the variable "mode" would be set to the old capturing mode. + old_mode_ = mode; + } + } + + ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaThreadExchangeStreamCaptureMode(&old_mode_)); + } + } + + private: + cudaStreamCaptureMode old_mode_; +}; +#else +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard(cudaStreamCaptureMode) {} +}; +#endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc new file mode 100644 index 0000000000..1f0d39e2ab --- /dev/null +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace platform { + +#ifdef PADDLE_WITH_CUDA +void BeginCUDAGraphCapture(platform::CUDAPlace place, + cudaStreamCaptureMode mode) { + auto stream = + platform::DeviceContextPool::Instance().GetByPlace(place)->stream(); + CUDAGraph::BeginCapture(place, stream, mode); + auto id = CUDAGraph::CapturingID(); + memory::allocation::AllocatorFacade::Instance().PrepareMemoryPoolForCUDAGraph( + id); + AddResetCallbackIfCapturingCUDAGraph([id] { + memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph( + id); + }); +} + +std::unique_ptr EndCUDAGraphCapture() { + return CUDAGraph::EndCapture(); +} +#endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h new file mode 100644 index 0000000000..f9f0248e51 --- /dev/null +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h @@ -0,0 +1,64 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_graph.h" +#endif + +namespace paddle { +namespace platform { + +// NOTE: These APIs are not thread-safe. +#ifdef PADDLE_WITH_CUDA +void BeginCUDAGraphCapture(platform::CUDAPlace place, + cudaStreamCaptureMode mode); +std::unique_ptr EndCUDAGraphCapture(); +#endif + +inline bool IsCUDAGraphCapturing() { +#ifdef PADDLE_WITH_CUDA + return CUDAGraph::IsCapturing(); +#else + return false; +#endif +} + +inline platform::CUDAPlace CUDAGraphCapturingPlace() { +#ifdef PADDLE_WITH_CUDA + return CUDAGraph::CapturingPlace(); +#else + PADDLE_THROW(platform::errors::Unimplemented( + "CUDA Graph is only supported on NVIDIA GPU device.")); +#endif +} + +// Add reset callback if CUDA Graph is capturing. +// Otherwise, invoke callback directly. +template +inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) { +#ifdef PADDLE_WITH_CUDA + if (UNLIKELY(IsCUDAGraphCapturing())) { + return CUDAGraph::AddResetCallbackDuringCapturing( + std::forward(callback)); + } +#endif + callback(); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index c4ac5aa304..59e4404ffe 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -22,6 +22,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_HIP #include "paddle/fluid/platform/dynload/miopen.h" #else +#include "paddle/fluid/platform/cuda_graph.h" #include "paddle/fluid/platform/dynload/cudnn.h" #endif #include "paddle/fluid/memory/malloc.h" @@ -557,6 +558,7 @@ class RecordedCudaMallocHelper { #ifdef PADDLE_WITH_HIP auto result = hipMalloc(ptr, size); #else + CUDAGraphCaptureModeGuard capture_mode_guard{cudaStreamCaptureModeRelaxed}; auto result = cudaMalloc(ptr, size); #endif if (result == gpuSuccess) { diff --git a/paddle/fluid/platform/type_defs.h b/paddle/fluid/platform/type_defs.h index f46bd1a0bd..88a2d16472 100644 --- a/paddle/fluid/platform/type_defs.h +++ b/paddle/fluid/platform/type_defs.h @@ -36,4 +36,5 @@ using gpuEvent_t = cudaEvent_t; using gpuDeviceProp = cudaDeviceProp; #endif +using CUDAGraphID = unsigned long long; // NOLINT } // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 22778013f2..875e6af965 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -7,7 +7,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator - cost_model) + cost_model cuda_graph_with_memory_pool) if (WITH_PSCORE) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index a16916ab33..6b24c64492 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -125,6 +125,8 @@ limitations under the License. */ #include "paddle/fluid/platform/xpu/xpu_info.h" #endif +#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" + #ifdef PADDLE_WITH_CRYPTO #include "paddle/fluid/pybind/crypto.h" #endif @@ -520,6 +522,19 @@ PYBIND11_MODULE(core_noavx, m) { m.def("nccl_version", &GetNCCLVersion); #endif + m.def("is_cuda_graph_capturing", &platform::IsCUDAGraphCapturing); +#ifdef PADDLE_WITH_CUDA + py::class_(m, "CUDAGraph") + .def_static("begin_capture", + [](platform::CUDAPlace place, int mode) { + platform::BeginCUDAGraphCapture( + place, static_cast(mode)); + }) + .def_static("end_capture", &platform::EndCUDAGraphCapture) + .def("replay", &platform::CUDAGraph::Replay) + .def("reset", &platform::CUDAGraph::Reset); +#endif + m.def("wait_device", [](const platform::Place &place) { platform::DeviceContextPool::Instance().Get(place)->Wait(); }); diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py new file mode 100644 index 0000000000..612f4d2c8c --- /dev/null +++ b/python/paddle/device/cuda/graphs.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace + +if is_compiled_with_cuda() and not is_compiled_with_rocm(): + from paddle.fluid.core import CUDAGraph as CoreCUDAGraph + + class CUDAGraph: + def __init__(self, place=None, mode="thread_local"): + ALL_MODES = ["global", "thread_local", "relaxed"] + self._graph = None + if place is None: + place = CUDAPlace(0) + self._place = place + assert mode in ALL_MODES + self._mode = ALL_MODES.index(mode) + + def capture_begin(self): + CoreCUDAGraph.begin_capture(self._place, self._mode) + + def capture_end(self): + self._graph = CoreCUDAGraph.end_capture() + + def replay(self): + self._graph.replay() + + def reset(self): + self._graph.reset() +else: + + class CUDAGraph: + def __init__(self, place=None, mode="thread_local"): + raise NotImplementedError() + + def capture_begin(self): + raise NotImplementedError() + + def capture_end(self): + raise NotImplementedError() + + def replay(self): + raise NotImplementedError() + + def reset(self): + raise NotImplementedError() diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph.py b/python/paddle/fluid/tests/unittests/test_cuda_graph.py new file mode 100644 index 0000000000..272d68e17f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cuda_graph.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +from paddle.device.cuda.graphs import CUDAGraph +import unittest +import numpy as np + + +class TestCUDAGraph(unittest.TestCase): + def setUp(self): + fluid.set_flags({'FLAGS_allocator_strategy': 'auto_growth'}) + + def random_tensor(self, shape): + return paddle.to_tensor( + np.random.randint( + low=0, high=10, size=shape).astype("float32")) + + def test_cuda_graph(self): + if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(): + return + + shape = [2, 3] + x = self.random_tensor(shape) + z = self.random_tensor(shape) + + g = CUDAGraph() + g.capture_begin() + y = x + 10 + z.add_(x) + g.capture_end() + + for _ in range(10): + z_np_init = z.numpy() + x_new = self.random_tensor(shape) + x.copy_(x_new, False) + g.replay() + x_np = x_new.numpy() + y_np = y.numpy() + z_np = z.numpy() + self.assertTrue((y_np - x_np == 10).all()) + self.assertTrue((z_np - z_np_init == x_np).all()) + + g.reset() + + +if __name__ == "__main__": + unittest.main() -- GitLab