// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "cuda.h" // NOLINT #include "cuda_runtime.h" // NOLINT #include "paddle/fluid/platform/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" namespace paddle { namespace platform { #if CUDA_VERSION >= 10010 static void ThrowErrorIfNotSupportCUDAGraph() {} #else enum cudaStreamCaptureMode { cudaStreamCaptureModeGlobal = 0, cudaStreamCaptureModeThreadLocal = 1, cudaStreamCaptureModeRelaxed = 2 }; static void ThrowErrorIfNotSupportCUDAGraph() { PADDLE_THROW(platform::errors::Unimplemented( "CUDA Graph is only supported when CUDA version >= 10.1")); } #endif // NOTE: Currently, we do not support to capture CUDA graph in parallel // NOTE: Do not use this class directly because it should be used with // the memory pool. class CUDAGraph { DISABLE_COPY_AND_ASSIGN(CUDAGraph); // Since the constructor would throw error is CUDA_VERSION < 10010. // The non-static method of CUDAGraph need not check CUDA_VERSION // again. CUDAGraph() { ThrowErrorIfNotSupportCUDAGraph(); } public: ~CUDAGraph() { Reset(); } CUDAGraphID ID() const { return id_; } void Replay(); void Reset(); void AddResetCallback(std::function callback) { std::lock_guard guard(mtx_); callbacks_.push_back(std::move(callback)); } static void BeginCapture(platform::CUDAPlace place, cudaStream_t stream, cudaStreamCaptureMode mode); static std::unique_ptr EndCapture(); static void AddResetCallbackDuringCapturing(std::function callback) { capturing_graph_->AddResetCallback(std::move(callback)); } // No need to add CUDA_VERSION macro because capturing_graph_ would // always be nullptr (constructor throws error) static bool IsCapturing() { return capturing_graph_ != nullptr; } static CUDAGraphID CapturingID() { return capturing_graph_->id_; } static platform::CUDAPlace CapturingPlace() { return capturing_graph_->place_; } // This API can be used to debug which GPU operation is not // supported during capturing CUDA Graph. static bool IsValidCapturing(); private: #if CUDA_VERSION >= 10010 cudaGraph_t graph_{nullptr}; cudaGraphExec_t exec_graph_{nullptr}; #endif cudaStream_t stream_{nullptr}; platform::CUDAPlace place_; CUDAGraphID id_{0}; std::vector> callbacks_; bool is_reset_{false}; std::mutex mtx_; static std::unique_ptr capturing_graph_; }; #if CUDA_VERSION >= 10010 class CUDAGraphCaptureModeGuard { DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); public: explicit CUDAGraphCaptureModeGuard( cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) { if (UNLIKELY(CUDAGraph::IsCapturing())) { PADDLE_ENFORCE_CUDA_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode)); // After cudaThreadExchangeStreamCaptureMode is called, // the variable "mode" would be set to the old capturing mode. old_mode_ = mode; } } ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW { if (UNLIKELY(CUDAGraph::IsCapturing())) { PADDLE_ENFORCE_CUDA_SUCCESS( cudaThreadExchangeStreamCaptureMode(&old_mode_)); } } private: cudaStreamCaptureMode old_mode_; }; #else class CUDAGraphCaptureModeGuard { DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); public: explicit CUDAGraphCaptureModeGuard( cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {} }; #endif } // namespace platform } // namespace paddle