diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 20a922b406745877df15eb79d3052381937b6a15..f7c57fa2b02d6c6ac9036636facd977ef2e3a82a 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -32,7 +32,7 @@ if(WITH_GPU OR WITH_ROCM) endif() if(WITH_GPU) - list(APPEND ALLOCATOR_DEPS cuda_graph) + list(APPEND ALLOCATOR_DEPS phi_backends) endif() if(CUDA_VERSION VERSION_GREATER_EQUAL 10.2) diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 8c93eaf2469d1126caad76f6c22e3e0eecfef81f..2db144f423fc7cbfb0e2144cb349aad3875284cd 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -85,7 +85,7 @@ if(WITH_GPU) nv_library( cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc - DEPS device_context allocator cuda_graph) + DEPS device_context allocator phi_backends) else() cc_library( cuda_graph_with_memory_pool diff --git a/paddle/fluid/platform/device/gpu/cuda/CMakeLists.txt b/paddle/fluid/platform/device/gpu/cuda/CMakeLists.txt index 64a2f891c21cdff453a865feccbfdd64b7a5dd8b..07901054b3b3370d15e4e42f5645010a5289016f 100644 --- a/paddle/fluid/platform/device/gpu/cuda/CMakeLists.txt +++ b/paddle/fluid/platform/device/gpu/cuda/CMakeLists.txt @@ -1,7 +1,3 @@ -nv_library( - cuda_graph - SRCS cuda_graph.cc - DEPS enforce) nv_library( cuda_profiler SRCS cuda_profiler.cc diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h index 5b5151ea822e820f55ff06954cff9a25ade12f7f..1c0843a0eb64578ebd311646e33d89f9efb83f2f 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h @@ -14,45 +14,23 @@ #pragma once -#include -#include -#include -#include -#include -#include - -#include "cuda.h" // NOLINT -#include "cuda_runtime.h" // NOLINT -#include 
"paddle/fluid/platform/device/gpu/gpu_types.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/utils/optional.h" +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" namespace paddle { namespace platform { +using CUDAKernelParams = phi::backends::gpu::CUDAKernelParams; +#if CUDA_VERSION < 10010 +using cudaStreamCaptureMode = phi::backends::gpu::cudaStreamCaptureMode; +#endif +using CUDAGraph = phi::backends::gpu::CUDAGraph; +using CUDAGraphCaptureModeGuard = phi::backends::gpu::CUDAGraphCaptureModeGuard; + template static bool IsBitwiseEqual(const T &x, const T &y) { return std::memcmp(&x, &y, sizeof(T)) == 0; } -class CUDAKernelParams { - public: - explicit CUDAKernelParams(const cudaKernelNodeParams *params) - : params_(params) {} - - const void *func() const { return params_->func; } - - template - T &As(size_t idx) const { - return *reinterpret_cast(params_->kernelParams[idx]); - } - - private: - const cudaKernelNodeParams *params_; -}; - template struct IsSameKernelHelper; @@ -96,191 +74,5 @@ struct IsSameKernelHelper { } }; -#if CUDA_VERSION >= 10010 -static void ThrowErrorIfNotSupportCUDAGraph() {} -#else -enum cudaStreamCaptureMode { - cudaStreamCaptureModeGlobal = 0, - cudaStreamCaptureModeThreadLocal = 1, - cudaStreamCaptureModeRelaxed = 2 -}; -static void ThrowErrorIfNotSupportCUDAGraph() { - PADDLE_THROW(platform::errors::Unimplemented( - "CUDA Graph is only supported when CUDA version >= 10.1")); -} -#endif - -// NOTE: Currently, we do not support to capture CUDA graph in parallel -// NOTE: Do not use this class directly because it should be used with -// the memory pool. -class CUDAGraph { - DISABLE_COPY_AND_ASSIGN(CUDAGraph); - - // Since the constructor would throw error is CUDA_VERSION < 10010. - // The non-static method of CUDAGraph need not check CUDA_VERSION - // again. 
- CUDAGraph() { - ThrowErrorIfNotSupportCUDAGraph(); - id_ = UniqueID(); - } - - public: - static constexpr int64_t kDefaultPoolID = 0; - static constexpr int64_t kInvalidPoolID = -1; - - ~CUDAGraph() { Reset(); } - - CUDAGraphID ID() const { return id_; } - - static int64_t SetMemoryPoolID(int64_t pool_id) { - auto &pool_id_ = capturing_graph_->pool_id_; - PADDLE_ENFORCE_EQ( - pool_id_, - kInvalidPoolID, - phi::errors::InvalidArgument("Cannot reset memory pool id twice, the " - "former memory pool id is %d.", - pool_id_)); - if (pool_id <= kInvalidPoolID) { - pool_id_ = UniqueMemoryPoolID(); - } else { - PADDLE_ENFORCE_GE( - pool_id, - kDefaultPoolID, - phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id)); - pool_id_ = pool_id; - } - return pool_id_; - } - - int64_t PoolID() const { return pool_id_; } - - static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; } - - void Replay(); - - void Reset(); - - void AddResetCallback(std::function callback) { - std::lock_guard guard(mtx_); - callbacks_.push_back(std::move(callback)); - } - - void PrintToDotFiles(const std::string &dirname, unsigned int flags); - - static void BeginCapture(platform::CUDAPlace place, - cudaStream_t stream, - cudaStreamCaptureMode mode); - static std::unique_ptr EndCapture(); - - static void BeginSegmentCapture(); - static void EndSegmentCapture(); - - static void AddResetCallbackDuringCapturing(std::function callback) { - capturing_graph_->AddResetCallback(std::move(callback)); - } - - // No need to add CUDA_VERSION macro because capturing_graph_ would - // always be nullptr (constructor throws error) - static bool IsCapturing() { return capturing_graph_ != nullptr; } - - static CUDAGraphID CapturingID() { return capturing_graph_->id_; } - - static platform::CUDAPlace CapturingPlace() { - return capturing_graph_->place_; - } - - // This API can be used to debug which GPU operation is not - // supported during capturing CUDA Graph. 
- static bool IsValidCapturing(); - - static bool IsThreadLocalCapturing() { -#if CUDA_VERSION >= 10010 - return IsCapturing() && - capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal; -#else - return false; -#endif - } - - static bool IsThisThreadCapturing() { - if (UNLIKELY(IsCapturing())) { - return IsThreadLocalCapturing() - ? capturing_thread_id_.get() == std::this_thread::get_id() - : true; - } else { - return false; - } - } - - using SetSeedFunc = std::function; - static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) { - std::lock_guard guard(capturing_graph_->func_mtx_); - capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func)); - } - - static int64_t UniqueMemoryPoolID(); - - private: - static CUDAGraphID UniqueID(); - - private: -#if CUDA_VERSION >= 10010 - std::vector graphs_; - std::vector exec_graphs_; - cudaStreamCaptureMode capture_mode_; -#endif - cudaStream_t stream_{nullptr}; - platform::CUDAPlace place_; - CUDAGraphID id_; - int64_t pool_id_{kInvalidPoolID}; - std::vector> callbacks_; - bool is_reset_{false}; - std::mutex mtx_; - - std::vector set_seed_funcs_; - std::vector>> pre_hooks_; - std::mutex func_mtx_; - - bool is_first_run_{true}; - - static paddle::optional capturing_thread_id_; - static std::unique_ptr capturing_graph_; -}; - -#if CUDA_VERSION >= 10010 -class CUDAGraphCaptureModeGuard { - DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); - - public: - explicit CUDAGraphCaptureModeGuard( - cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) { - if (UNLIKELY(CUDAGraph::IsCapturing())) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode)); - // After cudaThreadExchangeStreamCaptureMode is called, - // the variable "mode" would be set to the old capturing mode. 
- old_mode_ = mode; - } - } - - ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW { - if (UNLIKELY(CUDAGraph::IsCapturing())) { - PADDLE_ENFORCE_GPU_SUCCESS( - cudaThreadExchangeStreamCaptureMode(&old_mode_)); - } - } - - private: - cudaStreamCaptureMode old_mode_; -}; -#else -class CUDAGraphCaptureModeGuard { - DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); - - public: - explicit CUDAGraphCaptureModeGuard( - cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {} -}; -#endif - } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device/xpu/CMakeLists.txt b/paddle/fluid/platform/device/xpu/CMakeLists.txt index 31ac51050b87bdb2f27886061c5583b48687416f..242f2a8e26002e12a939c28fba489afa9049abb2 100644 --- a/paddle/fluid/platform/device/xpu/CMakeLists.txt +++ b/paddle/fluid/platform/device/xpu/CMakeLists.txt @@ -30,7 +30,7 @@ cc_library( xpulib device_context op_kernel_type - phi_xpu_op_list) + phi_backends) cc_library( xpu_resource_pool SRCS xpu_resource_pool.cc diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index a4c76ab0e68a0cd3fc8e6826ecbae669c2139c89..f8a6b2174a830acf212e26bb289204ede08943e6 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -7,7 +7,7 @@ if(WITH_GPU OR WITH_ROCM) list(APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc gpu/gpu_resources.cc) if(WITH_GPU) - list(APPEND BACKENDS_SRCS gpu/cuda/cuda_info.cc) + list(APPEND BACKENDS_SRCS gpu/cuda/cuda_info.cc gpu/cuda/cuda_graph.cc) endif() if(WITH_ROCM) list(APPEND BACKENDS_SRCS gpu/rocm/rocm_info.cc) @@ -16,8 +16,9 @@ if(WITH_GPU OR WITH_ROCM) endif() if(WITH_XPU) - add_subdirectory(xpu) list(APPEND BACKENDS_SRCS xpu/xpu_context.cc xpu/xpu_info.cc) + list(APPEND BACKENDS_SRCS xpu/xpu_op_list.cc xpu/xpu1_op_list.cc + xpu/xpu2_op_list.cc) endif() if(WITH_MKLDNN) diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc similarity index 
90% rename from paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc rename to paddle/phi/backends/gpu/cuda/cuda_graph.cc index 61c8fe4f4c5fd05fdc593395a4b503e68bd7b362..5fc39a5319945793bdf6c8c5c2a53bfde1434866 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h" +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" #include #include #include -namespace paddle { -namespace platform { +namespace phi { +namespace backends { +namespace gpu { std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; paddle::optional CUDAGraph::capturing_thread_id_{paddle::none}; @@ -113,7 +114,7 @@ void CUDAGraph::Replay() { #if CUDA_VERSION >= 10010 PADDLE_ENFORCE_EQ(is_reset_, false, - errors::PermissionDenied( + phi::errors::PermissionDenied( "Cannot replay the CUDA Graph after reset is called.")); size_t n = exec_graphs_.size(); for (size_t i = 0; i < n; ++i) { @@ -131,43 +132,43 @@ void CUDAGraph::Replay() { void CUDAGraph::BeginSegmentCapture() { ThrowErrorIfNotSupportCUDAGraph(); #if CUDA_VERSION >= 10010 - PADDLE_ENFORCE_EQ( - IsCapturing(), - true, - errors::PermissionDenied("BeginSegmentCapture should be called when CUDA " - "Graph is capturing.")); + PADDLE_ENFORCE_EQ(IsCapturing(), + true, + phi::errors::PermissionDenied( + "BeginSegmentCapture should be called when CUDA " + "Graph is capturing.")); if (IsThreadLocalCapturing()) { PADDLE_ENFORCE_EQ(IsThisThreadCapturing(), true, - platform::errors::PermissionDenied( + phi::errors::PermissionDenied( "When 
capturing CUDA Graph in the thread local mode, " "you cannot begin segmented capturing in the thread " "which is not the one that starts the capturing.")); } PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture( capturing_graph_->stream_, capturing_graph_->capture_mode_)); - PADDLE_ENFORCE_EQ(IsValidCapturing(), - true, - platform::errors::PermissionDenied( - "CUDA Graph should not be invalidated.")); + PADDLE_ENFORCE_EQ( + IsValidCapturing(), + true, + phi::errors::PermissionDenied("CUDA Graph should not be invalidated.")); VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_ << ", segment id " << capturing_graph_->graphs_.size() << ", memory pool id " << capturing_graph_->pool_id_; #endif } -void CUDAGraph::BeginCapture(platform::CUDAPlace place, +void CUDAGraph::BeginCapture(phi::GPUPlace place, cudaStream_t stream, cudaStreamCaptureMode mode) { ThrowErrorIfNotSupportCUDAGraph(); #if CUDA_VERSION >= 10010 - PADDLE_ENFORCE_EQ( - IsCapturing(), - false, - errors::PermissionDenied("CUDA Graph can only captured one by one.")); + PADDLE_ENFORCE_EQ(IsCapturing(), + false, + phi::errors::PermissionDenied( + "CUDA Graph can only captured one by one.")); PADDLE_ENFORCE_NOT_NULL( stream, - errors::PermissionDenied( + phi::errors::PermissionDenied( "CUDA Graph cannot be captured in default CUDA stream 0.")); capturing_graph_.reset(new CUDAGraph()); capturing_graph_->place_ = place; @@ -185,9 +186,10 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place, void CUDAGraph::EndSegmentCapture() { ThrowErrorIfNotSupportCUDAGraph(); #if CUDA_VERSION >= 10010 - PADDLE_ENFORCE_EQ(IsCapturing(), - true, - errors::PermissionDenied("No CUDA Graph is capturing.")); + PADDLE_ENFORCE_EQ( + IsCapturing(), + true, + phi::errors::PermissionDenied("No CUDA Graph is capturing.")); cudaGraph_t graph; PADDLE_ENFORCE_GPU_SUCCESS( cudaStreamEndCapture(capturing_graph_->stream_, &graph)); @@ -299,11 +301,12 @@ void CUDAGraph::PrintToDotFiles(const std::string &dirname, 
cudaGraphDebugDotPrint(graphs_[i], filename.c_str(), flags)); } #else - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "The print_to_dot_files() method is only supported when CUDA version >= " "11.3.")); #endif } -} // namespace platform -} // namespace paddle +} // namespace gpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..f2004eb6c7da0cac899a28b13394bf67dd96b3a5 --- /dev/null +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -0,0 +1,241 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "cuda.h" // NOLINT +#include "cuda_runtime.h" // NOLINT + +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/macros.h" +#include "paddle/utils/optional.h" + +namespace phi { +namespace backends { +namespace gpu { + +class CUDAKernelParams { + public: + explicit CUDAKernelParams(const cudaKernelNodeParams *params) + : params_(params) {} + + const void *func() const { return params_->func; } + + template + T &As(size_t idx) const { + return *reinterpret_cast(params_->kernelParams[idx]); + } + + private: + const cudaKernelNodeParams *params_; +}; + +#if CUDA_VERSION >= 10010 +static void ThrowErrorIfNotSupportCUDAGraph() {} +#else +enum cudaStreamCaptureMode { + cudaStreamCaptureModeGlobal = 0, + cudaStreamCaptureModeThreadLocal = 1, + cudaStreamCaptureModeRelaxed = 2 +}; +static void ThrowErrorIfNotSupportCUDAGraph() { + PADDLE_THROW(phi::errors::Unimplemented( + "CUDA Graph is only supported when CUDA version >= 10.1")); +} +#endif + +using CUDAGraphID = unsigned long long; // NOLINT + +// NOTE: Currently, capturing CUDA graphs in parallel is not supported. +// NOTE: Do not use this class directly because it should be used with +// the memory pool. +class CUDAGraph { + DISABLE_COPY_AND_ASSIGN(CUDAGraph); + + // Since the constructor would throw an error if CUDA_VERSION < 10010, + // the non-static methods of CUDAGraph need not check CUDA_VERSION + // again.
+ CUDAGraph() { + ThrowErrorIfNotSupportCUDAGraph(); + id_ = UniqueID(); + } + + public: + static constexpr int64_t kDefaultPoolID = 0; + static constexpr int64_t kInvalidPoolID = -1; + + ~CUDAGraph() { Reset(); } + + CUDAGraphID ID() const { return id_; } + + static int64_t SetMemoryPoolID(int64_t pool_id) { + auto &pool_id_ = capturing_graph_->pool_id_; + PADDLE_ENFORCE_EQ( + pool_id_, + kInvalidPoolID, + phi::errors::InvalidArgument("Cannot reset memory pool id twice, the " + "former memory pool id is %d.", + pool_id_)); + if (pool_id <= kInvalidPoolID) { + pool_id_ = UniqueMemoryPoolID(); + } else { + PADDLE_ENFORCE_GE( + pool_id, + kDefaultPoolID, + phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id)); + pool_id_ = pool_id; + } + return pool_id_; + } + + int64_t PoolID() const { return pool_id_; } + + static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; } + + void Replay(); + + void Reset(); + + void AddResetCallback(std::function callback) { + std::lock_guard guard(mtx_); + callbacks_.push_back(std::move(callback)); + } + + void PrintToDotFiles(const std::string &dirname, unsigned int flags); + + static void BeginCapture(phi::GPUPlace place, + cudaStream_t stream, + cudaStreamCaptureMode mode); + static std::unique_ptr EndCapture(); + + static void BeginSegmentCapture(); + static void EndSegmentCapture(); + + static void AddResetCallbackDuringCapturing(std::function callback) { + capturing_graph_->AddResetCallback(std::move(callback)); + } + + // No need to add CUDA_VERSION macro because capturing_graph_ would + // always be nullptr (constructor throws error) + static bool IsCapturing() { return capturing_graph_ != nullptr; } + + static CUDAGraphID CapturingID() { return capturing_graph_->id_; } + + static phi::GPUPlace CapturingPlace() { return capturing_graph_->place_; } + + // This API can be used to debug which GPU operation is not + // supported during capturing CUDA Graph. 
+ static bool IsValidCapturing(); + + static bool IsThreadLocalCapturing() { +#if CUDA_VERSION >= 10010 + return IsCapturing() && + capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal; +#else + return false; +#endif + } + + static bool IsThisThreadCapturing() { + if (UNLIKELY(IsCapturing())) { + return IsThreadLocalCapturing() + ? capturing_thread_id_.get() == std::this_thread::get_id() + : true; + } else { + return false; + } + } + + using SetSeedFunc = std::function; + static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) { + std::lock_guard guard(capturing_graph_->func_mtx_); + capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func)); + } + + static int64_t UniqueMemoryPoolID(); + + private: + static CUDAGraphID UniqueID(); + + private: +#if CUDA_VERSION >= 10010 + std::vector graphs_; + std::vector exec_graphs_; + cudaStreamCaptureMode capture_mode_; +#endif + cudaStream_t stream_{nullptr}; + phi::GPUPlace place_; + CUDAGraphID id_; + int64_t pool_id_{kInvalidPoolID}; + std::vector> callbacks_; + bool is_reset_{false}; + std::mutex mtx_; + + std::vector set_seed_funcs_; + std::vector>> pre_hooks_; + std::mutex func_mtx_; + + bool is_first_run_{true}; + + static paddle::optional capturing_thread_id_; + static std::unique_ptr capturing_graph_; +}; + +#if CUDA_VERSION >= 10010 +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard( + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode)); + // After cudaThreadExchangeStreamCaptureMode is called, + // the variable "mode" would be set to the old capturing mode. 
+ old_mode_ = mode; + } + } + + ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_GPU_SUCCESS( + cudaThreadExchangeStreamCaptureMode(&old_mode_)); + } + } + + private: + cudaStreamCaptureMode old_mode_; +}; +#else +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard( + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {} +}; +#endif + +} // namespace gpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/backends/xpu/CMakeLists.txt b/paddle/phi/backends/xpu/CMakeLists.txt deleted file mode 100644 index d84e6a63e058a3735fc06d5a6ac9cde6907e06bb..0000000000000000000000000000000000000000 --- a/paddle/phi/backends/xpu/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -cc_library( - phi_xpu_op_list - SRCS xpu_op_list.cc xpu1_op_list.cc xpu2_op_list.cc - DEPS glog) diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 8911da82b5480e478625462572f1abbac8851710..6dc43ff633f1910d89c70d05b23f22260fc32be5 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -19,7 +19,7 @@ if(WITH_XPU) cc_library( kernel_factory SRCS kernel_factory.cc - DEPS phi_enforce convert_utils phi_xpu_op_list) + DEPS phi_enforce convert_utils phi_backends) else() cc_library( kernel_factory diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc index a18e695cce4d8d6d0cbe1b5620d2f096a53a0b13..60747e36185a524af5dd1d1894daf95e6bcf810b 100644 --- a/paddle/phi/core/device_context.cc +++ b/paddle/phi/core/device_context.cc @@ -15,7 +15,7 @@ #include "paddle/phi/core/device_context.h" #ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h" +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" #endif #include "paddle/phi/core/dense_tensor.h" @@ -153,8 +153,9 @@ struct DeviceContext::Impl { : (pinned ? 
pinned_allocator_ : device_allocator_); #ifdef PADDLE_WITH_CUDA bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned; - if (must_cuda_graph_allocator && paddle::platform::is_gpu_place(place) && - paddle::platform::CUDAGraph::IsThisThreadCapturing()) { + if (must_cuda_graph_allocator && + place.GetType() == phi::AllocationType::GPU && + phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_, phi::errors::InvalidArgument( "Required cuda_graph_allocator_ shall not be " diff --git a/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h b/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h index ac4a60384af1942394cfaa5f91500e04314c2a6c..cc32759b5f044625a8464f9491baf5e155f01c51 100644 --- a/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h +++ b/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h @@ -465,7 +465,7 @@ struct SearchAlgorithmBase { static size_t GetWorkspaceSize(const ConvArgs& args, cudnnConvolutionBwdFilterAlgo_t algo) { - paddle::platform::CUDAGraphCaptureModeGuard guard; + phi::backends::gpu::CUDAGraphCaptureModeGuard guard; size_t workspace_size = 0; PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(