Unverified · Commit effebd41 authored by ronnywang, committed by GitHub

[ROCM] fix multihead_matmul (#54108)

* [ROCM] fix multihead_matmul

* skip bf16 uts

* update
Parent 2186fe16
@@ -261,9 +261,9 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
                                         const int batch_size,
                                         const int head_num,
                                         const int seq_len,
-                                        const unsigned mask) {
+                                        const phi::funcs::warp_mask_t mask) {
   int qk_offset = blockIdx.x * seq_len;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float tmp = threadIdx.x < seq_len
                   ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
@@ -281,15 +281,16 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
 // HIP defined __HIP_NO_HALF_CONVERSIONS__
 #ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
 template <>
-__global__ void SoftmaxKernelWithEltadd<half>(half *qk_buf_,
-                                              const half *bias_qk_,
-                                              const int batch_size,
-                                              const int head_num,
-                                              const int seq_len,
-                                              const unsigned mask) {
+__global__ void SoftmaxKernelWithEltadd<half>(
+    half *qk_buf_,
+    const half *bias_qk_,
+    const int batch_size,
+    const int head_num,
+    const int seq_len,
+    const phi::funcs::warp_mask_t mask) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   int qk_offset = blockIdx.x * seq_len;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float tmp = threadIdx.x < seq_len
                   ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
@@ -312,10 +313,10 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
                                          const int batch_size,
                                          const int head_num,
                                          const int seq_len,
-                                         const unsigned mask) {
+                                         const phi::funcs::warp_mask_t mask) {
   int qk_offset = blockIdx.x * seq_len;
   int idx = threadIdx.x;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float2 tmp = idx < seq_len
                    ? phi::funcs::ToFloat2<T>(qk_buf_[idx + qk_offset] +
@@ -335,19 +336,20 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
 }
 template <>
-__global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
-                                                const half2 *bias_qk_,
-                                                const int batch_size,
-                                                const int head_num,
-                                                const int seq_len,
-                                                const unsigned mask) {
+__global__ void SoftmaxKernelWithEltadd2<half2>(
+    half2 *qk_buf_,
+    const half2 *bias_qk_,
+    const int batch_size,
+    const int head_num,
+    const int seq_len,
+    const phi::funcs::warp_mask_t mask) {
 // operator "+" of half only suppotted after cuda version 10.0
 // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
 #if defined(PADDLE_WITH_CUDA) && \
     (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
   int qk_offset = blockIdx.x * seq_len;
   int idx = threadIdx.x;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float2 tmp = idx < seq_len
                    ? phi::funcs::ToFloat2<half2>(qk_buf_[idx + qk_offset] +
@@ -368,14 +370,15 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
 }
 template <typename T>
-__global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
-                                                const T *bias_qk,
-                                                const int batch_size,
-                                                const int head_num,
-                                                const int seq_len,
-                                                const unsigned mask) {
+__global__ void SoftmaxKernelWithEltaddForLarge(
+    T *qk_buf,
+    const T *bias_qk,
+    const int batch_size,
+    const int head_num,
+    const int seq_len,
+    const phi::funcs::warp_mask_t mask) {
   int qk_offset = blockIdx.x * seq_len;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   T stride_max = -1e20f;
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -406,15 +409,16 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
 // HIP defined __HIP_NO_HALF_CONVERSIONS__
 #ifndef __HIPCC__  // @{ Half kernel: SoftmaxKernelWithEltadd
 template <>
-__global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
-                                                const half *bias_qk,
-                                                const int batch_size,
-                                                const int head_num,
-                                                const int seq_len,
-                                                const unsigned mask) {
+__global__ void SoftmaxKernelWithEltaddForLarge(
+    half *qk_buf,
+    const half *bias_qk,
+    const int batch_size,
+    const int head_num,
+    const int seq_len,
+    const phi::funcs::warp_mask_t mask) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   int qk_offset = blockIdx.x * seq_len;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float stride_max = -1e20f;
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -444,14 +448,15 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
 #endif  // @} End Half kernel: SoftmaxKernelWithEltadd
 template <typename T>
-__global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
-                                                 const T *bias_qk_,
-                                                 const int batch_size,
-                                                 const int head_num,
-                                                 const int seq_len,
-                                                 const unsigned mask) {
+__global__ void SoftmaxKernelWithEltaddForLarge2(
+    T *qk_buf_,
+    const T *bias_qk_,
+    const int batch_size,
+    const int head_num,
+    const int seq_len,
+    const phi::funcs::warp_mask_t mask) {
   int qk_offset = blockIdx.x * seq_len;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float2 stride_max = make_float2(-1e20f, -1e20f);
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -484,19 +489,20 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
 }
 template <>
-__global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
-                                                 const half2 *bias_qk_,
-                                                 const int batch_size,
-                                                 const int head_num,
-                                                 const int seq_len,
-                                                 const unsigned mask) {
+__global__ void SoftmaxKernelWithEltaddForLarge2(
+    half2 *qk_buf_,
+    const half2 *bias_qk_,
+    const int batch_size,
+    const int head_num,
+    const int seq_len,
+    const phi::funcs::warp_mask_t mask) {
 // operator "+" of half only suppotted after cuda version 10.0
 // HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
 #if defined(PADDLE_WITH_CUDA) && \
     (CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
   int qk_offset = blockIdx.x * seq_len;
-  assert(blockDim.x % 32 == 0);
+  assert(blockDim.x % WARP_SIZE == 0);
   float2 stride_max = make_float2(-1e20f, -1e20f);
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -637,7 +643,7 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
     }
   }
-  if (blockDim.x <= 32) {
+  if (blockDim.x <= WARP_SIZE) {
     phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
   } else {
     phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
@@ -672,7 +678,7 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
     }
   }
-  if (blockDim.x <= 32) {
+  if (blockDim.x <= WARP_SIZE) {
     phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
   } else {
     phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
@@ -761,7 +767,10 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
   // Align block to 32, also limit seq_len to max block size.
   if (seq_len % 2 == 0) {
-    block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
+    block =
+        (seq_len <= (2 * WARP_SIZE))
+            ? WARP_SIZE
+            : ((seq_len + (2 * WARP_SIZE - 1)) / (2 * WARP_SIZE)) * WARP_SIZE;
     if (std::is_same<T, float>::value) {
       SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
           reinterpret_cast<float2 *>(qk_buf_),
@@ -780,7 +789,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
               "cuda_arch<700"));
 #else
           dim3 grid(seq_len, batch_size, head_num);
-          dim3 block((seq_len / 2 + 31) / 32 * 32);
+          dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
           SOFTMAX_KERNEL_WITH_MASK(1);
 #endif
         } else {
@@ -794,7 +803,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
       }
     }
   } else {
-    block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
+    block = (seq_len <= WARP_SIZE)
+                ? WARP_SIZE
+                : ((seq_len + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
     SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
         qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
   }
@@ -820,7 +831,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
               "cuda_arch<700"));
 #else
           dim3 grid(seq_len, batch_size, head_num);
-          dim3 block((seq_len / 2 + 31) / 32 * 32);
+          dim3 block((seq_len / 2 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
           if (block.x > 0 && block.x <= 1024) {
             SOFTMAX_KERNEL_WITH_MASK(1);
           } else if (block.x <= 2048) {
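Editor's note: the launch-side hunks above replace the hard-coded 32/64 constants with WARP_SIZE so the same rounding logic stays warp-aligned on both NVIDIA (warp = 32) and AMD (wavefront = 64) devices. Below is a minimal host-side sketch of that rounding rule for the vectorized path; RoundBlockToWarp is an illustrative helper name, not part of the commit.

#include <algorithm>

// Hedged sketch of the rounding rule used above for the vectorized
// (float2/half2) softmax path: each thread covers two elements, so
// seq_len / 2 threads are rounded up to a whole warp. Pass WARP_SIZE
// (32 on CUDA, 64 on HIP) as warp_size.
inline int RoundBlockToWarp(int seq_len, int warp_size) {
  int threads = (seq_len / 2 + warp_size - 1) / warp_size * warp_size;
  return std::max(threads, warp_size);  // never launch fewer than one warp
}

// Example: seq_len = 96 gives 48 active threads, which rounds up to 64
// with warp_size = 32 (two warps) and also to 64 with warp_size = 64
// (one full wavefront).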
@@ -1176,8 +1187,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
                                          float eps,
                                          gpuStream_t stream) {
   int block = num / hidden;
-  if (hidden <= 32) {
-    const int threads = 32;
+  if (hidden <= WARP_SIZE) {
+    const int threads = WARP_SIZE;
     SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
         num, hidden, input1, input2, output, scale, bias, eps);
   } else if (hidden <= 128) {
......
@@ -374,12 +374,20 @@ void BindPlace(pybind11::module &m) {  // NOLINT
       .def("__str__", string::to_string<const platform::CUDAPlace &>);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
     // Only GPUs with Compute Capability >= 53 support float16
+#ifdef PADDLE_WITH_HIP
+    return true;
+#else
     return platform::GetGPUComputeCapability(place.device) >= 53;
+#endif
   });
   m.def("is_bfloat16_supported", [](const platform::CUDAPlace &place) -> bool {
     // Only GPUs with Compute Capability >= 80 support bfloat16
+#ifdef PADDLE_WITH_HIP
+    return false;
+#else
     return platform::GetGPUComputeCapability(place.device) >= 80;
+#endif
   });
 #endif
   py::class_<platform::XPUPlace> xpuplace(m, "XPUPlace", R"DOC(
......
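Editor's note: the binding change above stops consulting the CUDA compute capability on ROCm builds and simply reports float16 as supported and bfloat16 as unsupported. As a rough illustration of the CUDA-side check only, the sketch below queries the compute capability directly through the CUDA runtime API; it is a simplified stand-in for platform::GetGPUComputeCapability, not the function the commit uses.

#include <cuda_runtime.h>

// Hedged sketch: apply the same thresholds as the bindings above
// (>= 5.3 for FP16, >= 8.0 for BF16) using the CUDA runtime API.
// On a ROCm build these thresholds do not apply, which is why the commit
// returns fixed values under #ifdef PADDLE_WITH_HIP instead.
inline int ComputeCapability(int device_id) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  return prop.major * 10 + prop.minor;  // e.g. 8.0 -> 80
}

inline bool IsFloat16Supported(int device_id) {
  return ComputeCapability(device_id) >= 53;
}

inline bool IsBfloat16Supported(int device_id) {
  return ComputeCapability(device_id) >= 80;
}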
@@ -163,12 +163,28 @@ struct KeyValuePair<half> {
   }
 };
+// NOTE(wangran16): The warpSize variable is of type int and contains the warp
+// size (in threads) for the target device. Note that all current NVIDIA devices
+// return 32 for this variable, and all current AMD devices return 64. Device
+// code should use the warpSize built-in to develop portable wave-aware code.
+#ifdef PADDLE_WITH_HIP
+#define FINAL_MASK 0xffffffffffffffffUL
+#define HALF_WARP 32
+#define WARP_SIZE 64
+#define WARP_SIZE_WIDTH 6
+#define WARP_SIZE_WIDTH_MASK 0x3f
+typedef u_int64_t warp_mask_t;
+#else
 #define FINAL_MASK 0xffffffff
 #define HALF_WARP 16
 #define WARP_SIZE 32
+#define WARP_SIZE_WIDTH 5
+#define WARP_SIZE_WIDTH_MASK 0x1f
+typedef unsigned warp_mask_t;
+#endif
 template <typename T>
-__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
+__inline__ __device__ T WarpReduceSum(T val, warp_mask_t lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
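Editor's note: several bit tricks in this header (lane = tid & mask, wid = tid >> shift) rely on WARP_SIZE, WARP_SIZE_WIDTH, and WARP_SIZE_WIDTH_MASK staying mutually consistent. The compile-time checks below make that relationship explicit; they are an illustrative addition, not part of the commit.

// Illustrative static_asserts (not in the commit) tying the new macros
// together: WARP_SIZE must equal 1 << WARP_SIZE_WIDTH and
// WARP_SIZE_WIDTH_MASK must equal WARP_SIZE - 1, so that
// (tid & WARP_SIZE_WIDTH_MASK) == tid % WARP_SIZE and
// (tid >> WARP_SIZE_WIDTH) == tid / WARP_SIZE on both CUDA and HIP.
static_assert((1 << WARP_SIZE_WIDTH) == WARP_SIZE,
              "WARP_SIZE_WIDTH must be log2(WARP_SIZE)");
static_assert(WARP_SIZE_WIDTH_MASK == WARP_SIZE - 1,
              "WARP_SIZE_WIDTH_MASK must be WARP_SIZE - 1");
static_assert(HALF_WARP * 2 == WARP_SIZE,
              "HALF_WARP is half of a warp/wavefront");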
@@ -180,10 +196,10 @@ __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T BlockReduceSum(T val, unsigned mask) {
+__inline__ __device__ T BlockReduceSum(T val, warp_mask_t mask) {
   static __shared__ T shared[WARP_SIZE];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
+  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
+  int wid = threadIdx.x >> WARP_SIZE_WIDTH;
   val = WarpReduceSum<T>(val, mask);
@@ -193,7 +209,7 @@ __inline__ __device__ T BlockReduceSum(T val, unsigned mask) {
   __syncthreads();
   // align block_span to warpSize
-  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
   val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
   val = WarpReduceSum<T>(val, mask);
@@ -208,8 +224,8 @@ __inline__ __device__ T WarpReduceSumV2(T *val) {
 #pragma unroll
   for (int i = 0; i < NUM; i++) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
+    for (int mask = HALF_WARP; mask > 0; mask >>= 1)
+      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, WARP_SIZE);
   }
   return (T)(0.0f);
 }
@@ -217,8 +233,8 @@ __inline__ __device__ T WarpReduceSumV2(T *val) {
 template <typename T, int NUM>
 __inline__ __device__ T BlockReduceSumV2(T *val) {
   static __shared__ T shared[NUM][33];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
+  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
+  int wid = threadIdx.x >> WARP_SIZE_WIDTH;
   WarpReduceSumV2<T, NUM>(val);
@@ -231,7 +247,7 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
   __syncthreads();
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
+  bool is_mask = threadIdx.x < (blockDim.x / static_cast<float>(WARP_SIZE));
 #pragma unroll
   for (int i = 0; i < NUM; i++) {
     val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
@@ -241,7 +257,7 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
 }
 template <typename T>
-__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) {
+__inline__ __device__ T WarpReduceMax(T val, warp_mask_t lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
@@ -256,14 +272,15 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) {
 #pragma unroll
   for (int i = 0; i < NUM; i++) {
 #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
+    for (int mask = HALF_WARP; mask > 0; mask >>= 1)
+      val[i] =
+          max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, WARP_SIZE));
   }
   return (T)(0.0f);
 }
 template <typename T>
-__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
+__inline__ __device__ T WarpReduceMin(T val, warp_mask_t lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
@@ -276,7 +293,7 @@ __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
 /* Calculate the minimum of all elements in a warp when actual quantity of
  * threads are less than warpSize.*/
 template <typename T>
-__inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
+__inline__ __device__ T PartialWarpReduceMin(T val, warp_mask_t lane_mask) {
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
   T warp_val = __shfl_sync(lane_mask, val, 0, warpSize);
 #else
@@ -297,10 +314,10 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
+__inline__ __device__ T BlockReduceMax(T val, warp_mask_t mask) {
   static __shared__ T shared[WARP_SIZE];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
+  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
+  int wid = threadIdx.x >> WARP_SIZE_WIDTH;
   val = WarpReduceMax(val, mask);
@@ -309,7 +326,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
   __syncthreads();
   // align block_span to warpSize
-  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
   val = (lane < block_span) ? shared[lane] : -1e10f;
   val = WarpReduceMax(val, mask);
@@ -318,9 +335,9 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
 template <typename T, int NUM>
 __inline__ __device__ T BlockReduceMaxV2(T *val) {
-  static __shared__ T shared[32][NUM];
-  int lane = threadIdx.x & 0x1f;  // in-warp idx
-  int wid = threadIdx.x >> 5;     // warp idx
+  static __shared__ T shared[WARP_SIZE][NUM];
+  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;  // in-warp idx
+  int wid = threadIdx.x >> WARP_SIZE_WIDTH;       // warp idx
   WarpReduceMaxV2<T, NUM>(val);  // get maxx in each warp
@@ -335,7 +352,7 @@ __inline__ __device__ T BlockReduceMaxV2(T *val) {
   // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
   // blockDim.x is not divided by 32
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
+  bool is_mask = threadIdx.x < (blockDim.x / static_cast<float>(WARP_SIZE));
 #pragma unroll
   for (int i = 0; i < NUM; i++) {
     val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
@@ -347,17 +364,17 @@ __inline__ __device__ T BlockReduceMaxV2(T *val) {
 /* Calculate the minimum of all elements in a block */
 template <typename T>
-__inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
+__inline__ __device__ T BlockReduceMin(T val, warp_mask_t mask) {
   static __shared__ T shared[WARP_SIZE];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
+  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
+  int wid = threadIdx.x >> WARP_SIZE_WIDTH;
   val = WarpReduceMin(val, mask);
   if (lane == 0) shared[wid] = val;
   __syncthreads();
   // align block_span to warpSize
-  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  int block_span = (blockDim.x + warpSize - 1) >> WARP_SIZE_WIDTH;
   val = (lane < block_span) ? shared[lane] : 1e10f;
   val = WarpReduceMin(val, mask);
@@ -367,11 +384,11 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
 /* Calculate the minimum of all elements in a warp when actual quantity of
  * threads are less than warpSize.*/
 template <typename T>
-__inline__ __device__ T PartialBlockReduceMin(T val, unsigned mask) {
+__inline__ __device__ T PartialBlockReduceMin(T val, warp_mask_t mask) {
   static __shared__ T shared[WARP_SIZE];
   static __shared__ T min_value;
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
+  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;
+  int wid = threadIdx.x >> WARP_SIZE_WIDTH;
   val = PartialWarpReduceMin(val, mask);
   if (lane == 0) shared[wid] = val;
......
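Editor's note: the BlockReduce* helpers above all follow the same pattern: reduce within each warp via shuffle, have lane 0 of each warp stash its partial result in shared memory, then reduce the per-warp partials. The sketch below condenses that pattern into a single warp-size-portable block sum. It covers only the CUDA-path intrinsic (__shfl_xor_sync) shown in the hunks above; the non-CUDA branch is elided in the diff and is not reproduced here, and BlockSumSketch is a hypothetical name, not a Paddle API.

// Hedged sketch (CUDA path only): a warp-size-portable block sum built from
// the macros introduced above.
template <typename T>
__inline__ __device__ T BlockSumSketch(T val) {
  static __shared__ T shared[WARP_SIZE];
  int lane = threadIdx.x & WARP_SIZE_WIDTH_MASK;  // index inside the warp
  int wid = threadIdx.x >> WARP_SIZE_WIDTH;       // warp index in the block

  // 1. Butterfly reduction inside each warp.
  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE);
  }

  // 2. Lane 0 of every warp publishes its partial sum.
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // 3. Reduce the per-warp partials (at most WARP_SIZE of them).
  int num_warps = (blockDim.x + WARP_SIZE - 1) >> WARP_SIZE_WIDTH;
  val = (lane < num_warps) ? shared[lane] : static_cast<T>(0);
  for (int offset = HALF_WARP; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE);
  }
  return val;
}

With WARP_SIZE = 64 on ROCm, the shared array size and the lane/warp arithmetic adjust automatically, which is the point of replacing the hard-coded 32, 0x1f, and >> 5 constants in this header.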
@@ -278,7 +278,8 @@ class TestSigmoid_ZeroDim(TestSigmoid):
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA",
 )
 class TestSigmoidBF16(OpTest):
     def setUp(self):
@@ -1237,7 +1238,8 @@ class TestSqrt_ZeroDim(TestSqrt):
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA",
 )
 class TestSqrtBF16(OpTest):
     def setUp(self):
@@ -3060,7 +3062,8 @@ class TestSquare_ZeroDim(TestSquare):
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA",
 )
 class TestSquareBF16(OpTest):
     def setUp(self):
@@ -3350,7 +3353,8 @@ class TestSoftplus_ZeroDim(TestSoftplus):
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA",
 )
 class TestSoftplusBF16(OpTest):
     def setUp(self):
......
@@ -154,6 +154,9 @@ class TestScaleFp16Op(TestScaleOp):
         self.check_grad(["X"], "Out")
+@unittest.skipIf(
+    not core.is_compiled_with_rocm(), "core is not compiled with CUDA"
+)
 class TestScaleBF16Op(OpTest):
     def setUp(self):
         self.op_type = "scale"
......
@@ -392,7 +392,8 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
 @unittest.skipIf(
-    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    "core is not compiled with CUDA",
 )
 class TestSoftmaxBF16Op(OpTest):
     def setUp(self):
......