未验证 提交 4da1a0fe 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] remove "gpu_device_function.h" in fluid. (#48117)

* move "paddle/phi/backends/gpu/gpu_device_function.h" to phi

* update copyright years

* rm "fluid/platform/device/gpu/gpu_device_function.h" in phi

* rm dependence to "gpu_device_function.h" in fluid

* rm gpu_device_function.h etc in fluid

* fix rocm-complie bugs

* fix cuda_helper_test.cu bugs
上级 2995f742
......@@ -13,7 +13,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
namespace paddle {
......
......@@ -42,7 +42,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/gpu/elementwise_grad.h"
......@@ -982,7 +982,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
#pragma unroll
for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
// reduce sum with wrap
val += platform::CudaShuffleXorSync(0xFFFFFFFF, val, i);
val += phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, val, i);
}
size_t idx_j = j + threadIdx.y;
......@@ -1004,7 +1004,8 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
#pragma unroll
for (int i = BLOCK_X >> 1; i > 0; i >>= 1) {
// reduce sum with wrap
inter_val += platform::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
inter_val +=
phi::backends::gpu::CudaShuffleXorSync(0xFFFFFFFF, inter_val, i);
}
if (threadIdx.x == 0 && (idx_j < w)) d_intermediate[idx_j] = inter_val;
}
......@@ -1160,14 +1161,14 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
if (BcastY) {
if (dy) {
val = paddle::platform::reduceSum(val, tid, h);
val = phi::backends::gpu::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else {
if (dx) {
val = paddle::platform::reduceSum(val, tid, h);
val = phi::backends::gpu::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
......@@ -1175,7 +1176,7 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
}
if (!SameShapeOfIntermediateOutAndOut) {
if (d_intermediate) {
inter_val = paddle::platform::reduceSum(inter_val, tid, h);
inter_val = phi::backends::gpu::reduceSum(inter_val, tid, h);
if (threadIdx.x == 0) {
d_intermediate[j] = inter_val;
}
......
......@@ -22,9 +22,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......
......@@ -19,8 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
namespace paddle {
namespace operators {
......
......@@ -22,10 +22,10 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/functors.h"
......
......@@ -25,8 +25,8 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fused_gate_attention.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......
......@@ -26,9 +26,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......
......@@ -21,7 +21,7 @@ namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
namespace paddle {
......
......@@ -25,8 +25,8 @@ namespace cub = hipcub;
#include <iostream>
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
......@@ -55,7 +55,7 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
return val;
}
......
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/beam_search.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
namespace paddle {
namespace operators {
......
......@@ -12,7 +12,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......@@ -242,7 +242,7 @@ __global__ void RowConvGradFilterImproved(const T *in,
for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32.
val += platform::CudaShuffleDownSync(mask, val, offset);
val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
__syncthreads();
......@@ -307,7 +307,7 @@ __global__ void RowConvGradFilter(const T *in,
for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32.
val += platform::CudaShuffleDownSync(mask, val, offset);
val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
__syncthreads();
......
......@@ -26,9 +26,9 @@ limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#define FINAL_MASK 0xffffffff
......@@ -283,8 +283,10 @@ __forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input,
if (largest) {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
int tmp_id = platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
T tmp_val =
phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
int tmp_id =
phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
if (input.v < tmp_val || (input.v == tmp_val && input.id > tmp_id)) {
input.v = tmp_val;
input.id = tmp_id;
......@@ -293,8 +295,10 @@ __forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input,
} else {
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
T tmp_val = platform::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
int tmp_id = platform::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
T tmp_val =
phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.v, offset);
int tmp_id =
phi::backends::gpu::CudaShuffleDownSync(FINAL_MASK, input.id, offset);
if (input.v > tmp_val || (input.v == tmp_val && input.id > tmp_id)) {
input.v = tmp_val;
input.id = tmp_id;
......@@ -357,7 +361,8 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (tid_max / 32 == wid) {
if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength)
if (phi::backends::gpu::CudaShuffleSync(mask, *beam, tid_max % 32, 32) ==
MaxLength)
break;
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace platform {
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \
case (dim): { \
constexpr auto kPowerOfTwoDim = (dim); \
__VA_ARGS__; \
} break
#define CUDA_LAUNCH_KERNEL_HELPER(...) \
CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);
template <typename T>
__forceinline__ __device__ T
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
}
template <typename T>
__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask,
T val,
int width = warpSize) {
return __shfl_xor_sync(mask, val, width);
}
template <>
__forceinline__ __device__ float16
CudaShuffleDownSync(unsigned mask, float16 val, int delta, int width) {
return float16(__shfl_down_sync(
mask, val.to_half(), static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ bfloat16
CudaShuffleDownSync(unsigned mask, bfloat16 val, int delta, int width) {
#if defined(PADDLE_CUDA_BF16)
return bfloat16(__shfl_down_sync(mask,
static_cast<nv_bfloat16>(val),
static_cast<unsigned>(delta),
width));
#else
PADDLE_ENFORCE(
false, "__shfl_down_sync with bfloat16 is not supported on cuda <= 11.");
#endif
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
float real = static_cast<float>(__shfl_down_sync(
mask, static_cast<float>(val.real), static_cast<unsigned>(delta), width));
float imag = static_cast<float>(__shfl_down_sync(
mask, static_cast<float>(val.imag), static_cast<unsigned>(delta), width));
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex<double>
CudaShuffleDownSync(unsigned mask,
paddle::platform::complex<double> val,
int delta,
int width) {
double real =
static_cast<double>(__shfl_down_sync(mask,
static_cast<double>(val.real),
static_cast<unsigned>(delta),
width));
double imag =
static_cast<double>(__shfl_down_sync(mask,
static_cast<double>(val.imag),
static_cast<unsigned>(delta),
width));
return paddle::platform::complex<double>(real, imag);
}
template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
float16 val,
int width) {
return float16(__shfl_xor_sync(mask, val.to_half(), width));
}
template <>
__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
bfloat16 val,
int width) {
#if defined(PADDLE_CUDA_BF16)
return bfloat16(__shfl_xor_sync(mask, static_cast<nv_bfloat16>(val), width));
#else
PADDLE_ENFORCE(
false, "__shfl_xor_sync with bfloat16 is not supported on cuda <= 11.");
#endif
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<float> val, int width) {
float real = static_cast<float>(
__shfl_xor_sync(mask, static_cast<float>(val.real), width));
float imag = static_cast<float>(
__shfl_xor_sync(mask, static_cast<float>(val.imag), width));
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<double> val, int width) {
double real = static_cast<double>(
__shfl_xor_sync(mask, static_cast<double>(val.real), width));
double imag = static_cast<double>(
__shfl_xor_sync(mask, static_cast<double>(val.imag), width));
return paddle::platform::complex<double>(real, imag);
}
template <typename T>
__forceinline__ __device__ T
CudaShuffleSync(unsigned mask, T val, int src_line, int width = 32) {
return __shfl_sync(mask, val, src_line, width);
}
template <typename T>
HOSTDEVICE T Infinity() {
return INFINITY;
}
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const int warpSize = 32;
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::CudaShuffleDownSync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::CudaShuffleDownSync(mask, val, offset);
}
return val;
}
} // namespace platform
} // namespace paddle
......@@ -22,9 +22,9 @@
#include <random>
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
using paddle::platform::float16;
......@@ -214,7 +214,7 @@ static __forceinline__ __device__ T WarpReduceSum(T val) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
val += paddle::platform::CudaShuffleDownSync(mask, val, offset);
val += phi::backends::gpu::CudaShuffleDownSync(mask, val, offset);
}
return val;
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/device/gpu/rocm/rocm_device_function.h"
#else
#include "paddle/fluid/platform/device/gpu/cuda/cuda_device_function.h"
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace platform {
#define CREATE_SHFL_MASK(mask, predicate) mask = __ballot((predicate))
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \
case (dim): { \
constexpr auto kPowerOfTwoDim = (dim); \
__VA_ARGS__; \
} break
#define CUDA_LAUNCH_KERNEL_HELPER(...) \
CUDA_LAUNCH_KERNEL_BASE(1024, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(512, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);
template <typename T>
__forceinline__ __device__ T
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
return __shfl_down(val, delta, width);
}
template <typename T>
__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask,
T val,
int width = warpSize) {
return __shfl_xor(val, width);
}
template <>
__forceinline__ __device__ float16
CudaShuffleDownSync(unsigned mask, float16 val, int delta, int width) {
return float16(__shfl_down(
static_cast<float>(val), static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ bfloat16
CudaShuffleDownSync(unsigned mask, bfloat16 val, int delta, int width) {
return bfloat16(__shfl_down(
static_cast<float>(val), static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleDownSync(
unsigned mask, paddle::platform::complex<float> val, int delta, int width) {
float real = __shfl_down(val.real, delta, width);
float imag = __shfl_down(val.imag, delta, width);
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex<double>
CudaShuffleDownSync(unsigned mask,
paddle::platform::complex<double> val,
int delta,
int width) {
double real = __shfl_down(val.real, delta, width);
double imag = __shfl_down(val.imag, delta, width);
return paddle::platform::complex<double>(real, imag);
}
template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
float16 val,
int width) {
return float16(__shfl_xor(static_cast<float>(val), width));
}
template <>
__forceinline__ __device__ bfloat16 CudaShuffleXorSync(unsigned mask,
bfloat16 val,
int width) {
return bfloat16(__shfl_xor(static_cast<float>(val), width));
}
template <>
__forceinline__ __device__ paddle::platform::complex<float> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<float> val, int width) {
float real = __shfl_xor(val.real, width);
float imag = __shfl_xor(val.imag, width);
return paddle::platform::complex<float>(real, imag);
}
template <>
__forceinline__ __device__ paddle::platform::complex<double> CudaShuffleXorSync(
unsigned mask, paddle::platform::complex<double> val, int width) {
double real = __shfl_xor(val.real, width);
double imag = __shfl_xor(val.imag, width);
return paddle::platform::complex<double>(real, imag);
}
template <typename T>
__forceinline__ __device__ T
CudaShuffleSync(unsigned mask, T val, int src_line, int width = 32) {
return __shfl(val, src_line, width);
}
template <typename T>
HOSTDEVICE T Infinity() {
return INFINITY;
}
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
#ifdef PADDLE_WITH_HIP
const int warpSize = 64;
#else
const int warpSize = 32;
#endif
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::CudaShuffleDownSync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::CudaShuffleDownSync(mask, val, offset);
}
return val;
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册