diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc
index 6bd08767871cc3a48549b3c0feeeb6e2083d82fe..99e6af20bd8ea045be1207861d861bc6563cd994 100644
--- a/paddle/fluid/memory/allocation/allocator_facade.cc
+++ b/paddle/fluid/memory/allocation/allocator_facade.cc
@@ -39,7 +39,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 
 #ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
 #endif
 
 #if CUDA_VERSION >= 10020
@@ -157,7 +157,7 @@ class CUDAGraphAllocator
 
 static bool IsCUDAGraphCapturing() {
 #ifdef PADDLE_WITH_CUDA
-  return UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing());
+  return UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing());
 #else
   return false;
 #endif
@@ -1007,7 +1007,7 @@ AllocatorFacade& AllocatorFacade::Instance() {
 AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const {
 #ifdef PADDLE_WITH_CUDA
   if (UNLIKELY(IsCUDAGraphCapturing())) {
-    auto id = platform::CUDAGraph::CapturingPoolID();
+    auto id = phi::backends::gpu::CUDAGraph::CapturingPoolID();
     auto iter = cuda_graph_map_.find(id);
     PADDLE_ENFORCE_NE(
         iter,
diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc
index 903f50d9a4cb7c69f09c2078dc3acfff83a24711..9f513448eea26604931e023b12d67cac55a38174 100644
--- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc
+++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc
@@ -19,7 +19,7 @@
 #include "paddle/phi/backends/gpu/gpu_info.h"
 
 #ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
 #endif
 
 namespace paddle {
@@ -49,7 +49,7 @@ void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) {
 
   std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_);
 #ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
+  if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) {
     graph_capturing_stream_set_.insert(stream);
     return;
   }
@@ -61,7 +61,7 @@ void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) {
 
 bool StreamSafeCUDAAllocation::CanBeFreed() {
 #ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
+  if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) {
     return graph_capturing_stream_set_.empty() &&
            outstanding_event_map_.empty();
   }
diff --git a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu
index 67f2df8cda5aaaa8d682211ebedca920a66da8c1..c4e87bef953d8f2529c59a479e521bdcbe313cf8 100644
--- a/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu
+++ b/paddle/fluid/memory/stream_safe_cuda_alloc_test.cu
@@ -319,7 +319,7 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
         data, result, data_num_);
     RecordStream(data_allocation, other_stream);
 
-    std::unique_ptr<platform::CUDAGraph> cuda_graph =
+    std::unique_ptr<phi::backends::gpu::CUDAGraph> cuda_graph =
         platform::EndCUDAGraphCapture();
 
     int replay_times = 10;
diff --git a/paddle/fluid/operators/cuda_graph_with_in_out.h b/paddle/fluid/operators/cuda_graph_with_in_out.h
index 40896c585c374eb76e754219adde024a5d154930..37be7bad1b3cb489367fdfe4072afd05a89a68be 100644
--- a/paddle/fluid/operators/cuda_graph_with_in_out.h
+++ b/paddle/fluid/operators/cuda_graph_with_in_out.h
@@ -89,7 +89,7 @@ class CUDAGraphWithInOuts {
   int64_t PoolID() const { return graph_->PoolID(); }
 
  private:
-  std::unique_ptr<platform::CUDAGraph> graph_;
+  std::unique_ptr<phi::backends::gpu::CUDAGraph> graph_;
   std::vector<phi::DenseTensor> ins_;
   std::vector<phi::DenseTensor> outs_;
   std::vector<int64_t> in_indices_;
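Note: every call site above swaps `paddle::platform::CUDAGraph` for `phi::backends::gpu::CUDAGraph` behind the same capture check. A minimal sketch of that shared guard pattern (the helper function itself is hypothetical and not part of this patch; `UNLIKELY` and the phi `CUDAGraph` API are real):

```cpp
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"

namespace {

// During stream capture no CUDA event may be recorded or queried, so
// stream-safe bookkeeping must be deferred until the capture ends.
inline bool ShouldDeferStreamBookkeeping() {
#ifdef PADDLE_WITH_CUDA
  return UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing());
#else
  return false;
#endif
}

}  // namespace
```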
diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index 67d72af23aaa4dd543a6aed02a9a89f56383e37d..80f8ccde263d6a70a769b942d92dea766b4126f1 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -14,10 +14,10 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/functors.h"
@@ -206,7 +206,7 @@ class FMHARef {
       stride_b = gemm_k * gemm_n;
 
       if (dropout_param_.dropout_prob_) {
-        DropoutFwGPUKernelDriver<T>(
+        phi::funcs::DropoutFwGPUKernelDriver<T>(
            static_cast<const phi::GPUContext&>(dev_ctx_),
            dropout_param_.is_test_,
            dropout_param_.dropout_prob_,
@@ -381,7 +381,7 @@
       stride_b = gemm_k * gemm_n;
 
       if (dropout_param_.dropout_prob_) {
-        DropoutFwGPUKernelDriver<T>(
+        phi::funcs::DropoutFwGPUKernelDriver<T>(
            static_cast<const phi::GPUContext&>(dev_ctx_),
            dropout_param_.is_test_,
            dropout_param_.dropout_prob_,
@@ -552,7 +552,7 @@
     }
     // dropout bw
     if (dropout_param_.dropout_prob_) {
-      DropoutGradGPUKernelDriver<T>(
+      phi::funcs::DropoutGradGPUKernelDriver<T>(
          static_cast<const phi::GPUContext&>(dev_ctx_),
          false,
          dropout_param_.dropout_prob_,
diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h
index f95d159144f3707b3bd8a7b234ee4fa8b42f2350..4ee01e058af1e68393c4f1133765dade01bdf6da 100644
--- a/paddle/fluid/operators/fused/fused_dropout_helper.h
+++ b/paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -15,10 +15,10 @@ limitations under the License.
 */
 #pragma once
 
 #include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/operators/dropout_impl_util.h"
 #include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
 #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
 #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
+#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
 #include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/layer_norm_kernel.h"
@@ -106,7 +106,7 @@ struct DropoutParam {
 
   int UpdateSeedAndIncrement(const phi::GPUContext& ctx, const int offset) {
     uint64_t tmp_increment;
-    GetSeedDataAndIncrement(
+    phi::funcs::GetSeedDataAndIncrement(
         ctx, tensor_seed, fix_seed, seed_val, offset, &seed, &tmp_increment);
     increment = static_cast<int>(tmp_increment);
     return increment;
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
index 9f049b6e248f72152566f7afcf00dc5d7fc0766b..7a5acb762eb83bbc52254cc2427938a1c8f0ba39 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc
@@ -15,7 +15,7 @@
 #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/backends/all_context.h"
 
 DECLARE_bool(use_stream_safe_cuda_allocator);
 
@@ -23,10 +23,10 @@ namespace paddle {
 namespace platform {
 
 #ifdef PADDLE_WITH_CUDA
-void BeginCUDAGraphCapture(platform::CUDAPlace place,
+void BeginCUDAGraphCapture(phi::GPUPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id) {
-  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
   auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
 
@@ -64,7 +64,7 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,
 
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
   auto place = CUDAGraph::CapturingPlace();
-  auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
   auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
   dev_ctx->cudnn_workspace_handle().ResetWorkspace();
   dev_ctx->SetCUDAGraphAllocator(nullptr);
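For context, the capture entry points above are typically driven as below (a minimal sketch modeled on the `stream_safe_cuda_alloc_test.cu` change earlier in this patch; the kernel launches that would be recorded are elided):

```cpp
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

void CaptureAndReplay(phi::GPUPlace place) {
  // Optionally pass a pool id from CUDAGraph::UniqueMemoryPoolID() to share
  // one memory pool across several graphs.
  paddle::platform::BeginCUDAGraphCapture(place, cudaStreamCaptureModeGlobal);

  // ... launch the kernels to be baked into the graph ...

  std::unique_ptr<phi::backends::gpu::CUDAGraph> graph =
      paddle::platform::EndCUDAGraphCapture();
  for (int i = 0; i < 10; ++i) {
    graph->Replay();
  }
  graph->Reset();  // drop the graph and run its reset callbacks
}
```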
diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h
index 2ad72c8239c2a5603fb18196aa57aa8a9d161ad8..78f36a77e5f9cdbedc9f1dc8edfedf1def1b5edb 100644
--- a/paddle/fluid/platform/cuda_graph_with_memory_pool.h
+++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h
@@ -14,123 +14,38 @@
 
 #pragma once
 
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/place.h"
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
-#endif
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/macros.h"
 
 namespace paddle {
 namespace platform {
 
-#ifdef PADDLE_WITH_CUDA
-#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                           \
-                                           __kernel_func,                    \
-                                           __grid,                           \
-                                           __block,                          \
-                                           __sm_size,                        \
-                                           __stream,                         \
-                                           __seed_inc,                       \
-                                           __seed_expr,                      \
-                                           __offset_expr,                    \
-                                           ...)                              \
-  do {                                                                       \
-    if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
-      using __Helper =                                                       \
-          ::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func),   \
-                                                 &__kernel_func>;            \
-      auto *dev_ctx =                                                        \
-          ::paddle::platform::DeviceContextPool::Instance().GetByPlace(      \
-              ::paddle::platform::CUDAGraph::CapturingPlace());              \
-      auto __set_seed_func =                                                 \
-          [=](::paddle::platform::CUDAKernelParams *__params,                \
-              bool __check_only) -> bool {                                   \
-        if (__check_only) {                                                  \
-          return __params->func() == &__kernel_func &&                       \
-                 __Helper::Compare(*__params, __VA_ARGS__);                  \
-        }                                                                    \
-        auto &KERNEL_PARAMS = *__params;                                     \
-        uint64_t __seed, __offset;                                           \
-        ::paddle::operators::GetSeedDataAndIncrement(                        \
-            *dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset);    \
-        __seed_expr = static_cast<decltype(__seed_expr)>(__seed);            \
-        __offset_expr = static_cast<decltype(__offset_expr)>(__offset);      \
-        return true;                                                         \
-      };                                                                     \
-      ::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
-    }                                                                        \
-    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);    \
-  } while (0)
-#else
-#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                           \
-                                           __kernel_func,                    \
-                                           __grid,                           \
-                                           __block,                          \
-                                           __sm_size,                        \
-                                           __stream,                         \
-                                           __seed_inc,                       \
-                                           __seed_expr,                      \
-                                           __offset_expr,                    \
-                                           ...)                              \
-  do {                                                                       \
-    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);    \
-  } while (0)
-#endif
-
 // NOTE: These APIs are not thread-safe.
 #ifdef PADDLE_WITH_CUDA
-void BeginCUDAGraphCapture(platform::CUDAPlace place,
+using CUDAGraph = phi::backends::gpu::CUDAGraph;
+
+void BeginCUDAGraphCapture(phi::GPUPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id = CUDAGraph::kInvalidPoolID);
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
 #endif
 
-inline bool IsCUDAGraphCapturing() {
-#ifdef PADDLE_WITH_CUDA
-  return CUDAGraph::IsCapturing();
-#else
-  return false;
-#endif
-}
-
-inline platform::CUDAPlace CUDAGraphCapturingPlace() {
+inline phi::GPUPlace CUDAGraphCapturingPlace() {
 #ifdef PADDLE_WITH_CUDA
   return CUDAGraph::CapturingPlace();
 #else
-  PADDLE_THROW(platform::errors::Unimplemented(
+  PADDLE_THROW(phi::errors::Unimplemented(
       "CUDA Graph is only supported on NVIDIA GPU device."));
 #endif
 }
 
-// Add reset callback if CUDA Graph is capturing.
-// Otherwise, invoke callback directly.
-template <typename Callback>
-inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
-#ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(IsCUDAGraphCapturing())) {
-    return CUDAGraph::AddResetCallbackDuringCapturing(
-        std::forward<Callback>(callback));
-  }
-#endif
-  callback();
-}
+using phi::backends::gpu::IsCUDAGraphCapturing;
 
-template <typename T>
-inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
-  static_assert(std::is_trivial<T>::value, "T must be trivial type");
-  static_assert(!std::is_same<T, void>::value, "T cannot be void");
-#ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(IsCUDAGraphCapturing())) {
-    size_t nbytes = size * sizeof(T);
-    void *new_host_mem = new uint8_t[nbytes];
-    std::memcpy(new_host_mem, host_mem, nbytes);
-    AddResetCallbackIfCapturingCUDAGraph(
-        [new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
-    return reinterpret_cast<T *>(new_host_mem);
-  }
-#endif
-  return host_mem;
-}
+using phi::backends::gpu::AddResetCallbackIfCapturingCUDAGraph;
+
+using phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph;
 
 class SkipCUDAGraphCaptureGuard {
   DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h b/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
deleted file mode 100644
index 1c0843a0eb64578ebd311646e33d89f9efb83f2f..0000000000000000000000000000000000000000
--- a/paddle/fluid/platform/device/gpu/cuda/cuda_graph.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
-
-namespace paddle {
-namespace platform {
-
-using CUDAKernelParams = phi::backends::gpu::CUDAKernelParams;
-#if CUDA_VERSION < 10010
-using cudaStreamCaptureMode = phi::backends::gpu::cudaStreamCaptureMode;
-#endif
-using CUDAGraph = phi::backends::gpu::CUDAGraph;
-using CUDAGraphCaptureModeGuard = phi::backends::gpu::CUDAGraphCaptureModeGuard;
-
-template <typename T>
-static bool IsBitwiseEqual(const T &x, const T &y) {
-  return std::memcmp(&x, &y, sizeof(T)) == 0;
-}
-
-template <typename FuncT, FuncT kernel_fn>
-struct IsSameKernelHelper;
-
-template <typename Return,
-          typename... FuncArgs,
-          Return (*kernel_fn)(FuncArgs...)>
-struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
- private:
-  using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
-
-  template <typename TupleT, size_t IDX, bool IsEnd>
-  struct Impl {
-    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
-      using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
-      if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
-                                    std::get<IDX>(args))) {
-        return false;
-      }
-
-      constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
-      return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
-    }
-  };
-
-  template <typename TupleT, size_t IDX>
-  struct Impl<TupleT, IDX, true> {
-    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
-      return true;
-    }
-  };
-
- public:
-  template <typename... Args>
-  static bool Compare(const CUDAKernelParams &params, Args... args) {
-    constexpr auto kNumArgs = sizeof...(FuncArgs);
-    static_assert(kNumArgs == sizeof...(Args), "Argument number not match");
-
-    auto args_tuple = std::make_tuple(args...);
-    using TupleT = typename std::decay<decltype(args_tuple)>::type;
-    return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
-  }
-};
-
-}  // namespace platform
-}  // namespace paddle
diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc
index 6952ce33a9318626fd2d9675b264bc77819c7a74..8023403df078d4878eec6687825c42ad5cb17d6a 100644
--- a/paddle/fluid/platform/device/gpu/gpu_info.cc
+++ b/paddle/fluid/platform/device/gpu/gpu_info.cc
@@ -36,8 +36,8 @@ limitations under the License. */
 #ifdef PADDLE_WITH_HIP
 #include "paddle/fluid/platform/dynload/miopen.h"
 #else
-#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
 #endif
 
 #ifdef PADDLE_WITH_CUDA
@@ -230,7 +230,7 @@ class RecordedGpuMallocHelper {
       result = hipMalloc(ptr, size);
     }
 #else
-    CUDAGraphCaptureModeGuard capture_mode_guard;
+    phi::backends::gpu::CUDAGraphCaptureModeGuard capture_mode_guard;
     if (UNLIKELY(malloc_managed_memory)) {
       result = cudaMallocManaged(ptr, size);
     } else {
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index f699b92e5045bc77486b82e61987b48a6e79c8c6..9ee0d3e4734e55c7970155b3093128e708a55e16 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -673,7 +673,7 @@ PYBIND11_MODULE(libpaddle, m) {
   m.def("is_cuda_graph_capturing", &platform::IsCUDAGraphCapturing);
 
 #ifdef PADDLE_WITH_CUDA
-  py::class_<platform::CUDAGraph>(m, "CUDAGraph")
+  py::class_<phi::backends::gpu::CUDAGraph>(m, "CUDAGraph")
       .def_static("begin_capture",
                   [](platform::CUDAPlace place, int mode) {
                     platform::BeginCUDAGraphCapture(
@@ -681,10 +681,11 @@ PYBIND11_MODULE(libpaddle, m) {
                   })
       .def_static("end_capture", &platform::EndCUDAGraphCapture)
       .def_static("gen_new_memory_pool_id",
-                  &platform::CUDAGraph::UniqueMemoryPoolID)
-      .def("replay", &platform::CUDAGraph::Replay)
-      .def("reset", &platform::CUDAGraph::Reset)
-      .def("print_to_dot_files", &platform::CUDAGraph::PrintToDotFiles);
+                  &phi::backends::gpu::CUDAGraph::UniqueMemoryPoolID)
+      .def("replay", &phi::backends::gpu::CUDAGraph::Replay)
+      .def("reset", &phi::backends::gpu::CUDAGraph::Reset)
+      .def("print_to_dot_files",
+           &phi::backends::gpu::CUDAGraph::PrintToDotFiles);
 #endif
 
   m.def("wait_device", [](const platform::Place &place) {
diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h
index f2004eb6c7da0cac899a28b13394bf67dd96b3a5..13054c347ef4eb0222456e0e7eab75f8543c2af3 100644
--- a/paddle/phi/backends/gpu/cuda/cuda_graph.h
+++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h
@@ -236,6 +236,54 @@ class CUDAGraphCaptureModeGuard {
 };
 #endif
 
+template <typename T>
+static bool IsBitwiseEqual(const T &x, const T &y) {
+  return std::memcmp(&x, &y, sizeof(T)) == 0;
+}
+
+template <typename FuncT, FuncT kernel_fn>
+struct IsSameKernelHelper;
+
+template <typename Return,
+          typename... FuncArgs,
+          Return (*kernel_fn)(FuncArgs...)>
+struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
+ private:
+  using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
+
+  template <typename TupleT, size_t IDX, bool IsEnd>
+  struct Impl {
+    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
+      using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
+      if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
+                                    std::get<IDX>(args))) {
+        return false;
+      }
+
+      constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
+      return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
+    }
+  };
+
+  template <typename TupleT, size_t IDX>
+  struct Impl<TupleT, IDX, true> {
+    static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
+      return true;
+    }
+  };
+
+ public:
+  template <typename... Args>
+  static bool Compare(const CUDAKernelParams &params, Args... args) {
+    constexpr auto kNumArgs = sizeof...(FuncArgs);
+    static_assert(kNumArgs == sizeof...(Args), "Argument number not match");
+
+    auto args_tuple = std::make_tuple(args...);
+    using TupleT = typename std::decay<decltype(args_tuple)>::type;
+    return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
+  }
+};
+
 }  // namespace gpu
 }  // namespace backends
 }  // namespace phi
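How the relocated `IsSameKernelHelper` is meant to be consumed: `Compare()` walks the recorded launch arguments at compile-time indices and bitwise-compares each one against the values passed in. A hedged sketch (`MyKernel` and the wrapper are made up for illustration; the helper and `CUDAKernelParams` are real):

```cpp
__global__ void MyKernel(const float *x, float *y, int n);

// Returns true when `params` records a launch of MyKernel with exactly
// these argument values (bitwise comparison, as in the macro below).
bool MatchesRecordedLaunch(const phi::backends::gpu::CUDAKernelParams &params,
                           const float *x,
                           float *y,
                           int n) {
  using Helper =
      phi::backends::gpu::IsSameKernelHelper<decltype(&MyKernel), &MyKernel>;
  return params.func() == &MyKernel && Helper::Compare(params, x, y, n);
}
```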
diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h
new file mode 100644
index 0000000000000000000000000000000000000000..1d39d3faf13e442171632ade129e3389eda90356
--- /dev/null
+++ b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h
@@ -0,0 +1,125 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstring>
+#include <type_traits>
+
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
+#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
+#endif
+
+namespace phi {
+namespace backends {
+namespace gpu {
+
+#ifdef PADDLE_WITH_CUDA
+#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                           \
+                                           __kernel_func,                    \
+                                           __grid,                           \
+                                           __block,                          \
+                                           __sm_size,                        \
+                                           __stream,                         \
+                                           __seed_inc,                       \
+                                           __seed_expr,                      \
+                                           __offset_expr,                    \
+                                           ...)                              \
+  do {                                                                       \
+    if (::phi::backends::gpu::CUDAGraph::IsThisThreadCapturing() &&          \
+        (__cond)) {                                                          \
+      using __Helper =                                                       \
+          ::phi::backends::gpu::IsSameKernelHelper<decltype(&__kernel_func), \
+                                                   &__kernel_func>;          \
+      auto *dev_ctx = ::phi::DeviceContextPool::Instance().GetByPlace(       \
+          ::phi::backends::gpu::CUDAGraph::CapturingPlace());                \
+      auto __set_seed_func =                                                 \
+          [=](::phi::backends::gpu::CUDAKernelParams *__params,              \
+              bool __check_only) -> bool {                                   \
+        if (__check_only) {                                                  \
+          return __params->func() == &__kernel_func &&                       \
+                 __Helper::Compare(*__params, __VA_ARGS__);                  \
+        }                                                                    \
+        auto &KERNEL_PARAMS = *__params;                                     \
+        uint64_t __seed, __offset;                                           \
+        ::phi::funcs::GetSeedDataAndIncrement(                               \
+            *dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset);    \
+        __seed_expr = static_cast<decltype(__seed_expr)>(__seed);            \
+        __offset_expr = static_cast<decltype(__offset_expr)>(__offset);      \
+        return true;                                                         \
+      };                                                                     \
+      ::phi::backends::gpu::CUDAGraph::RecordRandomKernelInfo(               \
+          __set_seed_func);                                                  \
+    }                                                                        \
+    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);    \
+  } while (0)
#else
+#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                           \
+                                           __kernel_func,                    \
+                                           __grid,                           \
+                                           __block,                          \
+                                           __sm_size,                        \
+                                           __stream,                         \
+                                           __seed_inc,                       \
+                                           __seed_expr,                      \
+                                           __offset_expr,                    \
+                                           ...)                              \
+  do {                                                                       \
+    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);    \
+  } while (0)
+#endif
+
+inline bool IsCUDAGraphCapturing() {
+#ifdef PADDLE_WITH_CUDA
+  return CUDAGraph::IsCapturing();
+#else
+  return false;
+#endif
+}
+
+// Add reset callback if CUDA Graph is capturing.
+// Otherwise, invoke callback directly.
+template <typename Callback>
+inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
+#ifdef PADDLE_WITH_CUDA
+  if (UNLIKELY(IsCUDAGraphCapturing())) {
+    return CUDAGraph::AddResetCallbackDuringCapturing(
+        std::forward<Callback>(callback));
+  }
+#endif
+  callback();
+}
+
+template <typename T>
+inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
+  static_assert(std::is_trivial<T>::value, "T must be trivial type");
+  static_assert(!std::is_same<T, void>::value, "T cannot be void");
+#ifdef PADDLE_WITH_CUDA
+  if (UNLIKELY(IsCUDAGraphCapturing())) {
+    size_t nbytes = size * sizeof(T);
+    void *new_host_mem = new uint8_t[nbytes];
+    std::memcpy(new_host_mem, host_mem, nbytes);
+    AddResetCallbackIfCapturingCUDAGraph(
+        [new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
+    return reinterpret_cast<T *>(new_host_mem);
+  }
+#endif
+  return host_mem;
+}
+
+}  // namespace gpu
+}  // namespace backends
+}  // namespace phi
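A hedged sketch of how `PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL` is consumed, modeled on phi's random kernels; this kernel and its parameter slot indices are illustrative, not code from the patch. `KERNEL_PARAMS` is only valid inside the macro's seed-setter lambda, which is where the `__seed_expr`/`__offset_expr` arguments are expanded:

```cpp
__global__ void UniformKernel(
    const float *in, uint64_t seed, uint64_t offset, float *out, size_t n);

void LaunchUniform(const float *in, float *out, size_t n,
                   uint64_t seed, uint64_t offset, size_t seed_increment,
                   dim3 grid, dim3 block, cudaStream_t stream) {
  // Slots 1 and 2 name the seed/offset launch arguments so the recorded
  // graph can refresh them before every replay.
  PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(true,
                                     UniformKernel,
                                     grid,
                                     block,
                                     0,
                                     stream,
                                     seed_increment,
                                     KERNEL_PARAMS.As<uint64_t>(1),
                                     KERNEL_PARAMS.As<uint64_t>(2),
                                     in,
                                     seed,
                                     offset,
                                     out,
                                     n);
}
```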
diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
index 57cf64d8df16c68a272c781670cdffef5155215c..3c1f9ec6cc19a44a652f8de3399c92de3be3f7e4 100644
--- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
@@ -14,7 +14,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 
 namespace phi {
 namespace funcs {
@@ -319,7 +319,7 @@ struct ConcatFunctor<phi::GPUContext, T> {
         context.GetPlace(),
         in_num * sizeof(T*),
         phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-    auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        inputs_data, in_num);
     paddle::memory::Copy(context.GetPlace(),
                          tmp_dev_ins_data->ptr(),
@@ -368,7 +368,7 @@ struct ConcatFunctor<phi::GPUContext, T> {
           inputs_col_num * sizeof(int64_t),
           phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
 
-      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+      auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
          inputs_col, inputs_col_num);
       paddle::memory::Copy(context.GetPlace(),
                            tmp_dev_ins_col_data->ptr(),
@@ -484,7 +484,7 @@ class SplitFunctor<phi::GPUContext, T> {
         context.GetPlace(),
         o_num * sizeof(T*),
         phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-    auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        outputs_data, o_num);
     paddle::memory::Copy(context.GetPlace(),
                          tmp_dev_outs_data->ptr(),
@@ -535,7 +535,7 @@ class SplitFunctor<phi::GPUContext, T> {
           context.GetPlace(),
           outputs_cols_num * sizeof(int64_t),
           phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
-      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
+      auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
          outputs_cols, outputs_cols_num);
       paddle::memory::Copy(context.GetPlace(),
                            tmp_dev_ins_col_data->ptr(),
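Why the call sites above wrap host pointers: while a graph is being captured, the host-to-device copy is recorded and re-executed on every `Replay()`, so the host source buffer must stay alive as long as the graph. `RestoreHostMemIfCapturingCUDAGraph` copies it to heap storage that a reset callback later frees. A minimal sketch of the pattern (the wrapper function is hypothetical):

```cpp
int64_t *KeepAliveForGraph(int64_t *host_buf, size_t n) {
  // Outside capture this returns host_buf unchanged; during capture it
  // returns a heap copy owned by the graph's reset callback.
  return phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(host_buf, n);
}
```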
diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h
similarity index 94%
rename from paddle/fluid/operators/dropout_impl.cu.h
rename to paddle/phi/kernels/funcs/dropout_impl.cu.h
index 413d02e3b67384c365719ffb0964cf90c0e2d6e7..40e6a7993606cd5bf1c0b5a5ffb0413ce55627f0 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h
@@ -19,35 +19,29 @@ limitations under the License.
 */
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <curand_kernel.h>
-
-#include "paddle/fluid/platform/dynload/curand.h"
 #endif
 #ifdef PADDLE_WITH_HIP
 #include <hip/hip_runtime.h>
 #include <hiprand_kernel.h>
-
-#include "paddle/fluid/platform/dynload/hiprand.h"
 #endif
 
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
-#include "paddle/fluid/operators/dropout_impl_util.h"
-#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
+
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
+#include "paddle/phi/kernels/primitive/compute_primitives.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {
 
 template <typename T>
 struct DstMaskFunctor {
   const float retain_prob_;
   const bool is_upscale_in_train_;
-  using MT = typename details::MPTypeTrait<T>::Type;
+  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
   MT factor;
   HOSTDEVICE inline DstMaskFunctor(const float retain_prob,
                                    const bool is_upscale_in_train)
@@ -149,7 +143,7 @@ __global__ void VectorizedRandomGenerator(const size_t n,
 template <typename T>
 struct MaskFunctor {
   const float retain_prob_;
-  using MT = typename details::MPTypeTrait<T>::Type;
+  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
   MT factor;
   HOSTDEVICE inline MaskFunctor(const float retain_prob)
       : retain_prob_(retain_prob) {
@@ -173,7 +167,7 @@ struct MaskFunctor {
 
 template <typename T>
 struct DstFunctor {
-  using MT = typename details::MPTypeTrait<T>::Type;
+  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
   MT factor;
   HOSTDEVICE inline DstFunctor(const float retain_prob,
                                const bool is_upscale_in_train,
@@ -271,7 +265,7 @@ inline void CalcBroadcastedMask(const phi::GPUContext& dev_ctx,
                                 phi::DenseTensor* broadcasted_mask) {
   // The broadcast of mask can be combined to the following ElementwiseKernel
   // when the BroadcastKernel supports different input types.
-  broadcasted_mask->mutable_data<uint8_t>(dev_ctx.GetPlace());
+  dev_ctx.template Alloc<uint8_t>(broadcasted_mask);
 
   std::vector<const phi::DenseTensor*> ins = {&mask};
   std::vector<phi::DenseTensor*> outs = {broadcasted_mask};
@@ -337,7 +331,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx,
       size_t block_size = gpu_config.GetBlockSize();
 
       int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
-      const auto& prop = platform::GetDeviceProperties(device_id);
+      const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
       size_t max_grid_size = prop.maxThreadsPerMultiProcessor *
                              prop.multiProcessorCount / block_size;
       grid_size = std::min(grid_size, max_grid_size);
@@ -393,9 +387,9 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx,
   } else {
     if (upscale_in_train) {
       // y = x
-      framework::TensorCopy(x, dev_ctx.GetPlace(), dev_ctx, y);
+      phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, y);
     } else {
-      using MT = typename details::MPTypeTrait<T>::Type;
+      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
       MT factor = static_cast<MT>(1.0f - dropout_prob);
       // y = factor * x
       ScaleByDropoutFactor<T, MT>(dev_ctx, x, y, factor);
@@ -405,7 +399,7 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx,
 
 template <typename T>
 struct CudaDropoutGradFunctor {
-  using MT = typename details::MPTypeTrait<T>::Type;
+  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
 
   explicit CudaDropoutGradFunctor(const MT factor) : factor_(factor) {}
 
@@ -428,7 +422,7 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
                                 const phi::DenseTensor& mask,
                                 phi::DenseTensor* grad_x,
                                 bool is_dropout_nd = false) {
-  using MT = typename details::MPTypeTrait<T>::Type;
+  using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
   auto stream = dev_ctx.stream();
 
   if (is_test) {
@@ -465,5 +459,5 @@ void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx,
   }
 }
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/phi/kernels/funcs/dropout_impl_util.h
similarity index 78%
rename from paddle/fluid/operators/dropout_impl_util.h
rename to paddle/phi/kernels/funcs/dropout_impl_util.h
index 84ff221cbe139c9f8b845b82a0dd3ef7397ef70d..ffb0d5dbd74d394ca2c2aed679941250094cff7d 100644
--- a/paddle/fluid/operators/dropout_impl_util.h
+++ b/paddle/phi/kernels/funcs/dropout_impl_util.h
@@ -14,11 +14,13 @@ limitations under the License.
 */
 #pragma once
 
-#include "paddle/fluid/framework/generator.h"
-#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/generator.h"
+#include "paddle/phi/core/tensor_utils.h"
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {
 
 inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
                                     const phi::DenseTensor* seed,
@@ -27,13 +29,11 @@ inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
                                     const int offset,
                                     uint64_t* seed_data,
                                     uint64_t* increment) {
-  int device_id = dev_ctx.GetPlace().GetDeviceId();
-  auto gen_cuda = framework::DefaultCUDAGenerator(device_id);
+  auto gen_cuda = dev_ctx.GetGenerator();
 
   if (seed) {
     phi::DenseTensor seed_cpu_tensor;
-    paddle::framework::TensorCopySync(
-        *seed, platform::CPUPlace(), &seed_cpu_tensor);
+    phi::Copy(dev_ctx, *seed, phi::CPUPlace(), true, &seed_cpu_tensor);
     *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
     *increment = offset;
   } else if (!is_fix_seed) {
@@ -46,5 +46,5 @@ inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx,
   }
 }
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/dropout_grad_kernel.cu b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu
index cdb8d0bd277622a4e116e055216daea327125917..d1a1cf8c27ab44b50e33548eba105bcb92c0c14f 100644
--- a/paddle/phi/kernels/gpu/dropout_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu
@@ -14,9 +14,9 @@
 
 #include "paddle/phi/kernels/dropout_grad_kernel.h"
 
-#include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
 
 namespace phi {
 
@@ -30,14 +30,14 @@ void DropoutGradRawKernel(const Context& dev_ctx,
                           DenseTensor* x_grad) {
   bool upscale_in_train = (mode == "upscale_in_train");
   dev_ctx.template Alloc<T>(x_grad);
-  paddle::operators::DropoutGradGPUKernelDriver<T>(dev_ctx,
-                                                   is_test,
-                                                   p.to<float>(),
-                                                   upscale_in_train,
-                                                   out_grad,
-                                                   mask,
-                                                   x_grad,
-                                                   false);
+  phi::funcs::DropoutGradGPUKernelDriver<T>(dev_ctx,
+                                            is_test,
+                                            p.to<float>(),
+                                            upscale_in_train,
+                                            out_grad,
+                                            mask,
+                                            x_grad,
+                                            false);
 }
 
 template <typename T, typename Context>
@@ -51,14 +51,14 @@ void DropoutNdGradKernel(const Context& dev_ctx,
                          DenseTensor* x_grad) {
   bool upscale_in_train = (mode == "upscale_in_train");
   dev_ctx.template Alloc<T>(x_grad);
-  paddle::operators::DropoutGradGPUKernelDriver<T>(dev_ctx,
-                                                   is_test,
-                                                   p.to<float>(),
-                                                   upscale_in_train,
-                                                   out_grad,
-                                                   mask,
-                                                   x_grad,
-                                                   true);
+  phi::funcs::DropoutGradGPUKernelDriver<T>(dev_ctx,
+                                            is_test,
+                                            p.to<float>(),
+                                            upscale_in_train,
+                                            out_grad,
+                                            mask,
+                                            x_grad,
+                                            true);
 }
 
 }  // namespace phi
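A hedged example of calling the relocated seed helper (cf. the `DropoutParam::UpdateSeedAndIncrement` change earlier in this patch; the wrapper function here is illustrative):

```cpp
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"

void QuerySeed(const phi::GPUContext &ctx) {
  uint64_t seed = 0;
  uint64_t increment = 0;
  // No seed tensor and no fixed seed: the seed comes from the context's
  // generator, which also advances its internal state by `offset`.
  phi::funcs::GetSeedDataAndIncrement(ctx,
                                      /*seed=*/nullptr,
                                      /*is_fix_seed=*/false,
                                      /*seed_val=*/0,
                                      /*offset=*/8,
                                      &seed,
                                      &increment);
}
```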
diff --git a/paddle/phi/kernels/gpu/dropout_kernel.cu b/paddle/phi/kernels/gpu/dropout_kernel.cu
index 34707bfd665f9043fedf85f4e27add49df840cb7..0e6e67faec39510a8cfb75672ce3d7bed0ab82f8 100644
--- a/paddle/phi/kernels/gpu/dropout_kernel.cu
+++ b/paddle/phi/kernels/gpu/dropout_kernel.cu
@@ -14,9 +14,9 @@
 
 #include "paddle/phi/kernels/dropout_kernel.h"
 
-#include "paddle/fluid/operators/dropout_impl.cu.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
 
 namespace phi {
 
@@ -36,17 +36,17 @@ void DropoutRawKernel(const Context& dev_ctx,
   if (mask) {
     dev_ctx.template Alloc<uint8_t>(mask);
   }
-  paddle::operators::DropoutFwGPUKernelDriver<T>(dev_ctx,
-                                                 is_test,
-                                                 p.to<float>(),
-                                                 upscale_in_train,
-                                                 fix_seed,
-                                                 seed,
-                                                 x,
-                                                 seed_tensor.get_ptr(),
-                                                 mask,
-                                                 out,
-                                                 false);
+  phi::funcs::DropoutFwGPUKernelDriver<T>(dev_ctx,
+                                          is_test,
+                                          p.to<float>(),
+                                          upscale_in_train,
+                                          fix_seed,
+                                          seed,
+                                          x,
+                                          seed_tensor.get_ptr(),
+                                          mask,
+                                          out,
+                                          false);
 }
 
 template <typename T, typename Context>
@@ -66,17 +66,17 @@ void DropoutNdKernel(const Context& dev_ctx,
   if (mask) {
     dev_ctx.template Alloc<uint8_t>(mask);
   }
-  paddle::operators::DropoutFwGPUKernelDriver<T>(dev_ctx,
-                                                 is_test,
-                                                 p.to<float>(),
-                                                 upscale_in_train,
-                                                 fix_seed,
-                                                 seed,
-                                                 x,
-                                                 seed_tensor.get_ptr(),
-                                                 mask,
-                                                 out,
-                                                 true);
+  phi::funcs::DropoutFwGPUKernelDriver<T>(dev_ctx,
+                                          is_test,
+                                          p.to<float>(),
+                                          upscale_in_train,
+                                          fix_seed,
+                                          seed,
+                                          x,
+                                          seed_tensor.get_ptr(),
+                                          mask,
+                                          out,
+                                          true);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h b/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h
index cc32759b5f044625a8464f9491baf5e155f01c51..852b0d77e5fe922438f19833f02290fc709ec98f 100644
--- a/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h
+++ b/paddle/phi/kernels/gpudnn/conv_cudnn_v7.h
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/kernels/autotune/switch_autotune.h"
 #include "paddle/phi/kernels/gpudnn/conv_gpudnn_base.h"