From 27ee6e714046e8cf6dd913854da167233f7f7c41 Mon Sep 17 00:00:00 2001 From: huangjiyi <43315610+huangjiyi@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:25:18 +0800 Subject: [PATCH] [PHI decoupling] move "gpu_device_function.h" from fluid to phi (#48097) * move "paddle/phi/backends/gpu/gpu_device_function.h" to phi * update copyright years * rm "fluid/platform/device/gpu/gpu_device_function.h" in phi * fix rocm-complie bugs --- .../backends/gpu/cuda/cuda_device_function.h | 191 ++++++++++++++++++ paddle/phi/backends/gpu/gpu_device_function.h | 24 +++ .../backends/gpu/rocm/rocm_device_function.h | 165 +++++++++++++++ .../phi/kernels/funcs/elementwise_grad_base.h | 36 ++-- paddle/phi/kernels/funcs/reduce_function.h | 2 +- .../phi/kernels/gpu/activation_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/activation_kernel.cu | 2 +- .../kernels/gpu/affine_grid_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/affine_grid_kernel.cu | 2 +- .../kernels/gpu/cross_entropy_grad_kernel.cu | 2 +- .../phi/kernels/gpu/cross_entropy_kernel.cu | 2 +- paddle/phi/kernels/gpu/depthwise_conv.h | 4 +- .../kernels/gpu/grid_sample_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/group_norm_utils.h | 2 +- paddle/phi/kernels/gpu/interpolate_kernel.cu | 2 +- .../kernels/gpudnn/affine_grid_grad_kernel.cu | 2 +- .../phi/kernels/gpudnn/affine_grid_kernel.cu | 2 +- paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 6 +- .../kernels/primitive/compute_primitives.h | 6 +- 19 files changed, 420 insertions(+), 36 deletions(-) create mode 100644 paddle/phi/backends/gpu/cuda/cuda_device_function.h create mode 100644 paddle/phi/backends/gpu/gpu_device_function.h create mode 100644 paddle/phi/backends/gpu/rocm/rocm_device_function.h diff --git a/paddle/phi/backends/gpu/cuda/cuda_device_function.h b/paddle/phi/backends/gpu/cuda/cuda_device_function.h new file mode 100644 index 00000000000..10aee53c45c --- /dev/null +++ b/paddle/phi/backends/gpu/cuda/cuda_device_function.h @@ -0,0 +1,191 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// NOTE(): support float16 to half in header file. +#define PADDLE_CUDA_FP16 +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace backends { +namespace gpu { + +#define FULL_WARP_MASK 0xFFFFFFFF +#define CREATE_SHFL_MASK(mask, predicate) \ + mask = __ballot_sync(FULL_WARP_MASK, (predicate)) + +#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kPowerOfTwoDim = (dim); \ + __VA_ARGS__; \ + } break + +#define CUDA_LAUNCH_KERNEL_HELPER(...) \ + CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__); + +template +__forceinline__ __device__ T +CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) { + return __shfl_down_sync(mask, val, static_cast(delta), width); +} + +template +__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, + T val, + int width = warpSize) { + return __shfl_xor_sync(mask, val, width); +} + +template <> +__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync( + unsigned mask, phi::dtype::float16 val, int delta, int width) { + return phi::dtype::float16(__shfl_down_sync( + mask, val.to_half(), static_cast(delta), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync( + unsigned mask, phi::dtype::bfloat16 val, int delta, int width) { +#if defined(PADDLE_CUDA_BF16) + return phi::dtype::bfloat16(__shfl_down_sync(mask, + static_cast(val), + static_cast(delta), + width)); +#else + PADDLE_ENFORCE( + false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11."); +#endif +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( + unsigned mask, phi::dtype::complex val, int delta, int width) { + float real = static_cast(__shfl_down_sync( + mask, static_cast(val.real), static_cast(delta), width)); + float imag = static_cast(__shfl_down_sync( + mask, static_cast(val.imag), static_cast(delta), width)); + return phi::dtype::complex(real, imag); +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( + unsigned mask, phi::dtype::complex val, int delta, int width) { + double real = + static_cast(__shfl_down_sync(mask, + static_cast(val.real), + static_cast(delta), + width)); + double imag = + static_cast(__shfl_down_sync(mask, + static_cast(val.imag), + static_cast(delta), + width)); + return phi::dtype::complex(real, imag); +} + +template <> +__forceinline__ __device__ phi::dtype::float16 CudaShuffleXorSync( + unsigned mask, phi::dtype::float16 val, int width) { + return phi::dtype::float16(__shfl_xor_sync(mask, val.to_half(), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync( + unsigned mask, phi::dtype::bfloat16 val, int width) { +#if defined(PADDLE_CUDA_BF16) + return phi::dtype::bfloat16( + __shfl_xor_sync(mask, static_cast(val), width)); +#else + PADDLE_ENFORCE( + false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11."); +#endif +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( + unsigned mask, phi::dtype::complex val, int width) { + float real = static_cast( + __shfl_xor_sync(mask, static_cast(val.real), width)); + float imag = static_cast( + __shfl_xor_sync(mask, static_cast(val.imag), width)); + return phi::dtype::complex(real, imag); +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( + unsigned mask, phi::dtype::complex val, int width) { + double real = static_cast( + __shfl_xor_sync(mask, static_cast(val.real), width)); + double imag = static_cast( + __shfl_xor_sync(mask, static_cast(val.imag), width)); + return phi::dtype::complex(real, imag); +} + +template +__forceinline__ __device__ T +CudaShuffleSync(unsigned mask, T val, int src_line, int width = 32) { + return __shfl_sync(mask, val, src_line, width); +} + +template +HOSTDEVICE T Infinity() { + return INFINITY; +} + +template +__device__ T reduceSum(T val, int tid, int len) { + // NOTE(zcd): The warp size should be taken from the + // parameters of the GPU but not specified as 32 simply. + // To make the reduceSum more efficiently, + // I use Warp-Level Parallelism and assume the Warp size + // is 32 which may be different for different GPU, + // but most card's warp size is 32. + const int warpSize = 32; + __shared__ T shm[warpSize]; + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, tid < len); + + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset); + + if (tid < warpSize) shm[tid] = 0; + __syncthreads(); + + if (tid % warpSize == 0) { + shm[tid / warpSize] = val; + } + __syncthreads(); + + CREATE_SHFL_MASK(mask, tid < warpSize); + + if (tid < warpSize) { + val = shm[tid]; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset); + } + return val; +} + +} // namespace gpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/backends/gpu/gpu_device_function.h b/paddle/phi/backends/gpu/gpu_device_function.h new file mode 100644 index 00000000000..0f79e2a645a --- /dev/null +++ b/paddle/phi/backends/gpu/gpu_device_function.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#ifdef PADDLE_WITH_HIP +#include "paddle/phi/backends/gpu/rocm/rocm_device_function.h" +#else +#include "paddle/phi/backends/gpu/cuda/cuda_device_function.h" +#endif + +#endif diff --git a/paddle/phi/backends/gpu/rocm/rocm_device_function.h b/paddle/phi/backends/gpu/rocm/rocm_device_function.h new file mode 100644 index 00000000000..6f5d684075f --- /dev/null +++ b/paddle/phi/backends/gpu/rocm/rocm_device_function.h @@ -0,0 +1,165 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +// NOTE(): support float16 to half in header file. +#define PADDLE_CUDA_FP16 +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/common/float16.h" + +namespace phi { +namespace backends { +namespace gpu { + +#define CREATE_SHFL_MASK(mask, predicate) mask = __ballot((predicate)) + +#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kPowerOfTwoDim = (dim); \ + __VA_ARGS__; \ + } break + +#define CUDA_LAUNCH_KERNEL_HELPER(...) \ + CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__); \ + CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__); + +template +__forceinline__ __device__ T +CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) { + return __shfl_down(val, delta, width); +} + +template +__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, + T val, + int width = warpSize) { + return __shfl_xor(val, width); +} + +template <> +__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync( + unsigned mask, phi::dtype::float16 val, int delta, int width) { + return phi::dtype::float16(__shfl_down( + static_cast(val), static_cast(delta), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync( + unsigned mask, phi::dtype::bfloat16 val, int delta, int width) { + return phi::dtype::bfloat16(__shfl_down( + static_cast(val), static_cast(delta), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( + unsigned mask, phi::dtype::complex val, int delta, int width) { + float real = __shfl_down(val.real, delta, width); + float imag = __shfl_down(val.imag, delta, width); + return phi::dtype::complex(real, imag); +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleDownSync( + unsigned mask, phi::dtype::complex val, int delta, int width) { + double real = __shfl_down(val.real, delta, width); + double imag = __shfl_down(val.imag, delta, width); + return phi::dtype::complex(real, imag); +} + +template <> +__forceinline__ __device__ phi::dtype::float16 CudaShuffleXorSync( + unsigned mask, phi::dtype::float16 val, int width) { + return phi::dtype::float16(__shfl_xor(static_cast(val), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleXorSync( + unsigned mask, phi::dtype::bfloat16 val, int width) { + return phi::dtype::bfloat16(__shfl_xor(static_cast(val), width)); +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( + unsigned mask, phi::dtype::complex val, int width) { + float real = __shfl_xor(val.real, width); + float imag = __shfl_xor(val.imag, width); + return phi::dtype::complex(real, imag); +} + +template <> +__forceinline__ __device__ phi::dtype::complex CudaShuffleXorSync( + unsigned mask, phi::dtype::complex val, int width) { + double real = __shfl_xor(val.real, width); + double imag = __shfl_xor(val.imag, width); + return phi::dtype::complex(real, imag); +} + +template +__forceinline__ __device__ T +CudaShuffleSync(unsigned mask, T val, int src_line, int width = 32) { + return __shfl(val, src_line, width); +} + +template +HOSTDEVICE T Infinity() { + return INFINITY; +} + +template +__device__ T reduceSum(T val, int tid, int len) { + // NOTE(zcd): The warp size should be taken from the + // parameters of the GPU but not specified as 32 simply. + // To make the reduceSum more efficiently, + // I use Warp-Level Parallelism and assume the Warp size + // is 32 which may be different for different GPU, + // but most card's warp size is 32. +#ifdef PADDLE_WITH_HIP + const int warpSize = 64; +#else + const int warpSize = 32; +#endif + __shared__ T shm[warpSize]; + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, tid < len); + + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset); + + if (tid < warpSize) shm[tid] = 0; + __syncthreads(); + + if (tid % warpSize == 0) { + shm[tid / warpSize] = val; + } + __syncthreads(); + + CREATE_SHFL_MASK(mask, tid < warpSize); + + if (tid < warpSize) { + val = shm[tid]; + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset); + } + return val; +} + +} // namespace gpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h index c55ce6a89ae..65f21e5b7f1 100644 --- a/paddle/phi/kernels/funcs/elementwise_grad_base.h +++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h @@ -24,7 +24,7 @@ limitations under the License. */ #if defined(__NVCC__) || defined(__HIPCC__) // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" @@ -504,7 +504,7 @@ static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x, } if (dd) { int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (tid == 0) { dd[bid] = val; } @@ -527,7 +527,7 @@ static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x, } if (dd) { int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (tid == 0) { dd[bid] = val; } @@ -569,7 +569,7 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel( } if (dy) { int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (tid == 0) { dy[bid] = val; } @@ -590,7 +590,7 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel( } if (dx) { int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (tid == 0) { dx[bid] = val; } @@ -636,7 +636,8 @@ static __global__ void FastCommonGradBroadcastCUDAKernelHeight(const T *x, if (dy) { T my_val = sdata[THREAD_ID_X][THREAD_ID_Y]; for (int i = warpSize >> 1; i > 0; i >>= 1) { - my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + my_val += + phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); } __syncthreads(); if ((THREAD_ID_X == 0)) { @@ -665,7 +666,8 @@ static __global__ void FastCommonGradBroadcastCUDAKernelHeight(const T *x, if (dy) { T my_val = sdata[THREAD_ID_X][THREAD_ID_Y]; for (int i = warpSize >> 1; i > 0; i >>= 1) { - my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + my_val += + phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); } __syncthreads(); if ((THREAD_ID_X == 0)) { @@ -709,7 +711,7 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x, if (dy) { h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (THREAD_ID_X == 0) { dy[j] = val; } @@ -726,7 +728,7 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x, if (dy) { h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (THREAD_ID_X == 0) { dy[j] = val; } @@ -764,7 +766,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x, if (dy) { h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (THREAD_ID_X == 0) { dy[j] = val; } @@ -783,7 +785,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x, if (dx) { h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (THREAD_ID_X == 0) { dx[j] = val; } @@ -835,7 +837,8 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel( if (dy) { T my_val = sdata[THREAD_ID_X][THREAD_ID_Y]; for (int i = warpSize >> 1; i > 0; i >>= 1) - my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + my_val += + phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); __syncthreads(); if ((THREAD_ID_X == 0)) { sdata[0][THREAD_ID_Y] = my_val; @@ -866,7 +869,8 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel( if (dx) { T my_val = sdata[THREAD_ID_X][THREAD_ID_Y]; for (int i = warpSize >> 1; i > 0; i >>= 1) - my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + my_val += + phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); __syncthreads(); if ((THREAD_ID_X == 0)) { sdata[0][THREAD_ID_Y] = my_val; @@ -921,7 +925,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x, if (dy) { int h = pre * post; h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (THREAD_ID_X == 0) { dy[j] = val; } @@ -948,7 +952,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x, if (dx) { int h = pre * post; h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); + val = phi::backends::gpu::reduceSum(val, tid, h); if (THREAD_ID_X == 0) { dx[j] = val; } @@ -1054,7 +1058,7 @@ __global__ void CommonGradBroadcastCUDAKernel(const int *x_strides_array, out_index = C_index; val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]); } - val = paddle::platform::reduceSum(val, tid, thread_num); + val = phi::backends::gpu::reduceSum(val, tid, thread_num); if (THREAD_ID_X == 0) { dx[i] = val; } diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 9719fbd8816..1b1a55b25c5 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -33,8 +33,8 @@ namespace cub = hipcub; #endif #ifndef PADDLE_WITH_XPU_KP -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #endif diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 5e75909649a..2c2ca16e262 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -14,8 +14,8 @@ limitations under the License. */ #include "paddle/phi/kernels/activation_grad_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index df8ae72346a..5168a1de073 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -14,8 +14,8 @@ limitations under the License. */ #include "paddle/phi/kernels/activation_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu b/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu index a7a82236a40..886aaa76e41 100644 --- a/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/affine_grid_grad_kernel.cu @@ -16,10 +16,10 @@ #include "paddle/phi/kernels/affine_grid_grad_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpu/affine_grid_kernel.cu b/paddle/phi/kernels/gpu/affine_grid_kernel.cu index 499ed260eef..8274e687512 100644 --- a/paddle/phi/kernels/gpu/affine_grid_kernel.cu +++ b/paddle/phi/kernels/gpu/affine_grid_kernel.cu @@ -16,10 +16,10 @@ #include "paddle/phi/kernels/affine_grid_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu index 5d40304c5e0..df3e4bd0cf1 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu @@ -24,8 +24,8 @@ namespace cub = hipcub; #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index 76201a1077e..bee9fc801b7 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -24,8 +24,8 @@ namespace cub = hipcub; #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h index 5da0ae96e6b..9ed88135041 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h +++ b/paddle/phi/kernels/gpu/depthwise_conv.h @@ -27,7 +27,7 @@ limitations under the License. */ namespace cub = hipcub; #endif -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/kernels/funcs/math_function.h" @@ -92,7 +92,7 @@ class DepthwiseConvFilterGradFunctor { template __forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val += platform::CudaShuffleDownSync(lane_mask, val, mask); + val += phi::backends::gpu::CudaShuffleDownSync(lane_mask, val, mask); return val; } diff --git a/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu index 8f4beaa2677..6e8b12c4b1b 100644 --- a/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu @@ -14,7 +14,7 @@ #include "paddle/phi/kernels/grid_sample_grad_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" diff --git a/paddle/phi/kernels/gpu/group_norm_utils.h b/paddle/phi/kernels/gpu/group_norm_utils.h index 00986817c61..3cb13692d52 100644 --- a/paddle/phi/kernels/gpu/group_norm_utils.h +++ b/paddle/phi/kernels/gpu/group_norm_utils.h @@ -22,7 +22,7 @@ namespace cub = hipcub; #endif -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" diff --git a/paddle/phi/kernels/gpu/interpolate_kernel.cu b/paddle/phi/kernels/gpu/interpolate_kernel.cu index 625718e8f4b..8135e73142f 100644 --- a/paddle/phi/kernels/gpu/interpolate_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_kernel.cu @@ -14,8 +14,8 @@ #include "paddle/phi/kernels/interpolate_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/amp_type_traits.h" diff --git a/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu b/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu index 4bc8c205025..d1cc738e2b0 100644 --- a/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/affine_grid_grad_kernel.cu @@ -15,11 +15,11 @@ #ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/affine_grid_grad_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu b/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu index 98f200480d4..6c5d305abbf 100644 --- a/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu +++ b/paddle/phi/kernels/gpudnn/affine_grid_kernel.cu @@ -15,11 +15,11 @@ #ifndef PADDLE_WITH_HIP #include "paddle/phi/kernels/affine_grid_kernel.h" -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index 99cd4c9b6d8..a81357e99b5 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -24,8 +24,8 @@ limitations under the License. */ #include "paddle/phi/kernels/primitive/kernel_primitives.h" // See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #define MATRIX_SOFTMAX_ALIGN_BYTES 16 #define MATRIX_SOFTMAX_THREAHOLD 100000 @@ -133,7 +133,7 @@ __device__ __forceinline__ void WarpReduceSum(T* sum) { #pragma unroll for (int i = 0; i < BatchSize; ++i) { T sum_val = - paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); sum[i] = sum[i] + sum_val; } } @@ -146,7 +146,7 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) { #pragma unroll for (int i = 0; i < BatchSize; ++i) { T max_val = - paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); + phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); sum[i] = max(sum[i], max_val); } } diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h index b3da4197662..1dfcde4e5dd 100644 --- a/paddle/phi/kernels/primitive/compute_primitives.h +++ b/paddle/phi/kernels/primitive/compute_primitives.h @@ -21,7 +21,7 @@ #include #endif -#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/common/float16.h" namespace phi { @@ -65,7 +65,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) { unsigned mask = 0u; CREATE_SHFL_MASK(mask, true); for (int stride = details::kWarpSize / 2; stride > 0; stride >>= 1) { - T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); + T temp = phi::backends::gpu::CudaShuffleDownSync(mask, val, stride); val = reducer(val, temp); } return val; @@ -110,7 +110,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { unsigned mask = 0u; CREATE_SHFL_MASK(mask, true); for (int stride = 1; stride < block_dim_x; stride <<= 1) { - T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride); + T temp = phi::backends::gpu::CudaShuffleDownSync(mask, val, stride); val = reducer(val, temp); } __syncthreads(); -- GitLab