From 39210ed0989fa003aeff6c2f266d056742333805 Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Tue, 10 Jan 2023 10:15:41 +0800 Subject: [PATCH] Refine name style and MoeKernel (#49432) --- .../framework/details/nan_inf_utils_detail.cu | 6 +- .../plugin/merge_layernorm_op_plugin.cu | 4 +- .../plugin/preln_residual_bias_plugin.cu | 38 +- .../plugin/skip_merge_layernorm_op_plugin.cu | 4 +- .../operators/math/bert_encoder_functor.cu | 117 +---- .../operators/optimizers/lars_momentum_op.cu | 12 +- paddle/phi/kernels/funcs/math_cuda_utils.h | 104 +++- .../phi/kernels/fusion/cutlass/moe_kernel.cu | 49 +- .../kernels/fusion/cutlass/moe_kernel_impl.h | 471 +++++++----------- paddle/phi/kernels/gpu/dist_kernel.cu | 6 +- .../kernels/gpu/interpolate_grad_kernel.cu | 12 +- .../sparse/gpu/fused_attention_grad_kernel.cu | 2 +- .../sparse/gpu/fused_attention_kernel.cu | 4 +- .../kernels/sparse/gpu/softmax_grad_kernel.cu | 2 +- .../phi/kernels/sparse/gpu/softmax_kernel.cu | 4 +- 15 files changed, 338 insertions(+), 497 deletions(-) diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/fluid/framework/details/nan_inf_utils_detail.cu index 4056fdc9cf..8754a33b66 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cu +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cu @@ -216,9 +216,9 @@ __device__ void BlockReduceMaxMinAndWrite(const T max_value, if (max_ptr && min_ptr && mean_ptr) { __syncthreads(); - T block_max_value = phi::funcs::blockReduceMax(max_value, FINAL_MASK); - T block_min_value = phi::funcs::blockReduceMin(min_value, FINAL_MASK); - T block_mean_value = phi::funcs::blockReduceSum(mean_value, FINAL_MASK); + T block_max_value = phi::funcs::BlockReduceMax(max_value, FINAL_MASK); + T block_min_value = phi::funcs::BlockReduceMin(min_value, FINAL_MASK); + T block_mean_value = phi::funcs::BlockReduceSum(mean_value, FINAL_MASK); if (threadIdx.x == 0) { max_ptr[offset] = block_max_value; diff --git a/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu index d94d4395b3..e969b75773 100644 --- a/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu @@ -68,7 +68,7 @@ __global__ void merge_layernorm_v2(T *out, } } - mean = phi::funcs::blockReduceSum(sum, FINAL_MASK); + mean = phi::funcs::BlockReduceSum(sum, FINAL_MASK); if (tid == 0) { s_mean = mean / n; } @@ -84,7 +84,7 @@ __global__ void merge_layernorm_v2(T *out, } } - variance = phi::funcs::blockReduceSum(var, FINAL_MASK); + variance = phi::funcs::BlockReduceSum(var, FINAL_MASK); if (tid == 0) { s_variance = rsqrtf(variance / n + layernorm_eps); } diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu index 5db53958e8..945401c940 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu @@ -26,6 +26,7 @@ #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" namespace paddle { namespace inference { @@ -33,41 +34,6 @@ namespace tensorrt { namespace plugin { #ifdef TRT_PLUGIN_FP16_AVALIABLE #define FINAL_MASK 0xffffffff -template -__inline__ 
__device__ T warpReduceSumV2(T *val) { -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceSumV2(T *val) { - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduceSumV2(val); - - if (lane == 0) { -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[i][wid] = val[i]; - } - } - __syncthreads(); - - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[i][lane] : (T)(0.0f); - } - warpReduceSumV2(val); - return (T)0.0f; -} template __global__ void generalAddBiasResidualLayerNormOpt2( @@ -119,7 +85,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2( float sums[2]; sums[0] = x_sum; sums[1] = x2_sum; - blockReduceSumV2(sums); + phi::funcs::BlockReduceSumV2(sums); if (threadIdx.x == 0) { s_mean = sums[0] / n / 2; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu index 09c971a858..71c8292cf7 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu @@ -70,7 +70,7 @@ __global__ void merge_layernorm_v2(T *out, } } - mean = phi::funcs::blockReduceSum(sum, FINAL_MASK); + mean = phi::funcs::BlockReduceSum(sum, FINAL_MASK); if (tid == 0) { s_mean = mean / n; } @@ -86,7 +86,7 @@ __global__ void merge_layernorm_v2(T *out, } } - variance = phi::funcs::blockReduceSum(var, FINAL_MASK); + variance = phi::funcs::BlockReduceSum(var, FINAL_MASK); if (tid == 0) { s_variance = rsqrtf(variance / n + layernorm_eps); } diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index a97ab99dc2..8c5225edaf 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -269,10 +269,10 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, ? static_cast(qk_buf_[threadIdx.x + qk_offset] + bias_qk_[threadIdx.x + qk_offset]) : -1e20f; - float max_val = phi::funcs::blockReduceMax(tmp, mask); + float max_val = phi::funcs::BlockReduceMax(tmp, mask); float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f; - float sum_val = phi::funcs::blockReduceSum(qk_tmp, mask); + float sum_val = phi::funcs::BlockReduceSum(qk_tmp, mask); if (threadIdx.x < seq_len) qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val); @@ -295,10 +295,10 @@ __global__ void SoftmaxKernelWithEltadd(half *qk_buf_, ? static_cast(qk_buf_[threadIdx.x + qk_offset] + bias_qk_[threadIdx.x + qk_offset]) : -1e20f; - float max_val = phi::funcs::blockReduceMax(tmp, mask); + float max_val = phi::funcs::BlockReduceMax(tmp, mask); float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f; - float sum_val = phi::funcs::blockReduceSum(qk_tmp, mask); + float sum_val = phi::funcs::BlockReduceSum(qk_tmp, mask); if (threadIdx.x < seq_len) qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val); @@ -321,12 +321,12 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, ? 
phi::funcs::ToFloat2(qk_buf_[idx + qk_offset] + bias_qk_[idx + qk_offset]) : make_float2(-1e20f, -1e20f); - float max_val = phi::funcs::blockReduceMax(max(tmp.x, tmp.y), mask); + float max_val = phi::funcs::BlockReduceMax(max(tmp.x, tmp.y), mask); float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val), __expf(tmp.y - max_val)) : make_float2(0.f, 0.f); float sum_val = - phi::funcs::blockReduceSum(qk_tmp.x + qk_tmp.y, mask) + 1e-6f; + phi::funcs::BlockReduceSum(qk_tmp.x + qk_tmp.y, mask) + 1e-6f; if (idx < seq_len) { qk_buf_[idx + qk_offset] = @@ -353,12 +353,12 @@ __global__ void SoftmaxKernelWithEltadd2(half2 *qk_buf_, ? phi::funcs::ToFloat2(qk_buf_[idx + qk_offset] + bias_qk_[idx + qk_offset]) : make_float2(-1e20f, -1e20f); - float max_val = phi::funcs::blockReduceMax(max(tmp.x, tmp.y), mask); + float max_val = phi::funcs::BlockReduceMax(max(tmp.x, tmp.y), mask); float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val), __expf(tmp.y - max_val)) : make_float2(0.f, 0.f); float sum_val = - phi::funcs::blockReduceSum(qk_tmp.x + qk_tmp.y, mask) + 1e-6f; + phi::funcs::BlockReduceSum(qk_tmp.x + qk_tmp.y, mask) + 1e-6f; if (idx < seq_len) { qk_buf_[idx + qk_offset] = @@ -386,14 +386,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf, bias_qk[threadIdx.x + i + qk_offset] : stride_max; } - T max_val = phi::funcs::blockReduceMax(stride_max, mask); + T max_val = phi::funcs::BlockReduceMax(stride_max, mask); T stride_sum = 0.f; for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] + bias_qk[threadIdx.x + i + qk_offset] - max_val); } - T sum_val = phi::funcs::blockReduceSum(stride_sum, mask); + T sum_val = phi::funcs::BlockReduceSum(stride_sum, mask); for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { qk_buf[threadIdx.x + i + qk_offset] = @@ -422,7 +422,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf, bias_qk[threadIdx.x + i + qk_offset]); stride_max = tmp > stride_max ? 
tmp : stride_max; } - float max_val = phi::funcs::blockReduceMax(stride_max, mask); + float max_val = phi::funcs::BlockReduceMax(stride_max, mask); float stride_sum = 0.f; for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -430,7 +430,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf, bias_qk[threadIdx.x + i + qk_offset]); stride_sum += __expf(tmp - max_val); } - float sum_val = phi::funcs::blockReduceSum(stride_sum, mask); + float sum_val = phi::funcs::BlockReduceSum(stride_sum, mask); for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { float tmp = @@ -461,7 +461,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_, stride_max.y = max(stride_max.y, cur.y); } float max_val = - phi::funcs::blockReduceMax(max(stride_max.x, stride_max.y), mask); + phi::funcs::BlockReduceMax(max(stride_max.x, stride_max.y), mask); float2 stride_sum = make_float2(0.f, 0.f); for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -472,7 +472,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_, } float sum_val = - phi::funcs::blockReduceSum(stride_sum.x + stride_sum.y, mask) + + phi::funcs::BlockReduceSum(stride_sum.x + stride_sum.y, mask) + 1e-6f; for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -507,7 +507,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_, stride_max.y = max(stride_max.y, cur.y); } float max_val = - phi::funcs::blockReduceMax(max(stride_max.x, stride_max.y), mask); + phi::funcs::BlockReduceMax(max(stride_max.x, stride_max.y), mask); float2 stride_sum = make_float2(0.f, 0.f); for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -519,7 +519,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_, } float sum_val = - phi::funcs::blockReduceSum(stride_sum.x + stride_sum.y, mask) + + phi::funcs::BlockReduceSum(stride_sum.x + stride_sum.y, mask) + 1e-6f; for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) { @@ -573,83 +573,6 @@ inline __device__ T hadd2(T a, T b) { return __hadd2(a, b); } -template -__inline__ __device__ T warpReduceSumV2(T *val) { -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceSumV2(T *val) { - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduceSumV2(val); - - if (lane == 0) { -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[i][wid] = val[i]; - } - } - - __syncthreads(); - - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[i][lane] : (T)(0.0f); - } - warpReduceSumV2(val); - return (T)0.0f; -} - -template -__inline__ __device__ T warpReduceMaxV2(T *val) { -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceMaxV2(T *val) { - static __shared__ T shared[32][NUM]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - warpReduceMaxV2(val); // get maxx in each warp - - if (lane == 0) { -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[wid][i] = val[i]; - } - } - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent - // blockDim.x is not divided by 32 - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[lane][i] : (T)-1e20f; - } - warpReduceMaxV2(val); - - return (T)0.0f; -} - template __global__ void softmax_kernel_with_mask(T *qk_buf_, const T *attr_mask, @@ -715,9 +638,9 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_, } if (blockDim.x <= 32) { - warpReduceMaxV2(local_max); + phi::funcs::WarpReduceMaxV2(local_max); } else { - blockReduceMaxV2(local_max); + phi::funcs::BlockReduceMaxV2(local_max); } if (threadIdx.x == 0) { @@ -750,9 +673,9 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_, } if (blockDim.x <= 32) { - warpReduceSumV2(local_sum); + phi::funcs::WarpReduceSumV2(local_sum); } else { - blockReduceSumV2(local_sum); + phi::funcs::BlockReduceSumV2(local_sum); } if (threadIdx.x == 0) { diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index 43ac1532d4..c91752fef0 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -187,8 +187,8 @@ __global__ void L2NormKernel( g_tmp += (tmp1 * tmp1); tid += grid_stride; } - p_tmp = phi::funcs::blockReduceSum(p_tmp, FINAL_MASK); - g_tmp = phi::funcs::blockReduceSum(g_tmp, FINAL_MASK); + p_tmp = phi::funcs::BlockReduceSum(p_tmp, FINAL_MASK); + g_tmp = phi::funcs::BlockReduceSum(g_tmp, FINAL_MASK); if (threadIdx.x == 0) { p_buffer[blockIdx.x] = p_tmp; @@ -198,8 +198,8 @@ __global__ void L2NormKernel( cg->sync(); // Grid sync for writring partial result to gloabl memory MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0; MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0; - MT tmp0 = phi::funcs::blockReduceSum(p_part_sum, FINAL_MASK); - MT tmp1 = phi::funcs::blockReduceSum(g_part_sum, FINAL_MASK); + MT tmp0 = phi::funcs::BlockReduceSum(p_part_sum, FINAL_MASK); + MT tmp1 = phi::funcs::BlockReduceSum(g_part_sum, FINAL_MASK); if (threadIdx.x == 0) { s_buffer[0] = tmp0; s_buffer[1] = tmp1; @@ -393,8 +393,8 @@ __global__ void MomentumLarsKernel(const T* param, MT grad_part_norm = threadIdx.x < thresh ? 
g_buffer[threadIdx.x] : 0; __syncthreads(); MT param_norm = - Sqrt(phi::funcs::blockReduceSum(param_part_norm, FINAL_MASK)); - MT grad_norm = Sqrt(rescale_grad_pow * phi::funcs::blockReduceSum( + Sqrt(phi::funcs::BlockReduceSum(param_part_norm, FINAL_MASK)); + MT grad_norm = Sqrt(rescale_grad_pow * phi::funcs::BlockReduceSum( grad_part_norm, FINAL_MASK)); #endif MomentumUpdate(param, diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index f7bbf7ad0b..b493e2ac41 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -168,7 +168,7 @@ struct KeyValuePair { #define WARP_SIZE 32 template -__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) { +__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val += __shfl_xor_sync(lane_mask, val, mask, warpSize); @@ -180,12 +180,12 @@ __inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) { /* Calculate the sum of all elements in a block */ template -__inline__ __device__ T blockReduceSum(T val, unsigned mask) { +__inline__ __device__ T BlockReduceSum(T val, unsigned mask) { static __shared__ T shared[WARP_SIZE]; int lane = threadIdx.x & 0x1f; int wid = threadIdx.x >> 5; - val = warpReduceSum(val, mask); + val = WarpReduceSum(val, mask); __syncthreads(); if (lane == 0) shared[wid] = val; @@ -195,13 +195,53 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; val = (lane < block_span) ? shared[lane] : static_cast(0.0f); - val = warpReduceSum(val, mask); + val = WarpReduceSum(val, mask); return val; } +/* +WarpReduce multi values. +*/ +template +__inline__ __device__ T WarpReduceSumV2(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T BlockReduceSumV2(T *val) { + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + WarpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? 
shared[i][lane] : (T)(0.0f); + } + WarpReduceSumV2(val); + return (T)0.0f; +} + template -__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) { +__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); @@ -211,8 +251,19 @@ __inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) { return val; } +template +__inline__ __device__ T WarpReduceMaxV2(T *val) { +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + template -__inline__ __device__ T warpReduceMin(T val, unsigned lane_mask) { +__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); @@ -246,12 +297,12 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) { /* Calculate the maximum of all elements in a block */ template -__inline__ __device__ T blockReduceMax(T val, unsigned mask) { +__inline__ __device__ T BlockReduceMax(T val, unsigned mask) { static __shared__ T shared[WARP_SIZE]; int lane = threadIdx.x & 0x1f; int wid = threadIdx.x >> 5; - val = warpReduceMax(val, mask); + val = WarpReduceMax(val, mask); if (lane == 0) shared[wid] = val; @@ -260,26 +311,55 @@ __inline__ __device__ T blockReduceMax(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; val = (lane < block_span) ? shared[lane] : -1e10f; - val = warpReduceMax(val, mask); + val = WarpReduceMax(val, mask); return val; } +template +__inline__ __device__ T BlockReduceMaxV2(T *val) { + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + WarpReduceMaxV2(val); // get maxx in each warp + + if (lane == 0) { // record in-warp maxx by warp Idx +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[lane][i] : (T)-1e20f; + } + WarpReduceMaxV2(val); + + return (T)0.0f; +} + /* Calculate the minimum of all elements in a block */ template -__inline__ __device__ T blockReduceMin(T val, unsigned mask) { +__inline__ __device__ T BlockReduceMin(T val, unsigned mask) { static __shared__ T shared[WARP_SIZE]; int lane = threadIdx.x & 0x1f; int wid = threadIdx.x >> 5; - val = warpReduceMin(val, mask); + val = WarpReduceMin(val, mask); if (lane == 0) shared[wid] = val; __syncthreads(); // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; val = (lane < block_span) ? 
shared[lane] : 1e10f; - val = warpReduceMin(val, mask); + val = WarpReduceMin(val, mask); return val; } diff --git a/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu b/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu index dcdb30ed9a..a0c6719c3c 100644 --- a/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu +++ b/paddle/phi/kernels/fusion/cutlass/moe_kernel.cu @@ -160,30 +160,17 @@ void InitExpertChoiceRouteKernelLauncher( <<>>(reinterpret_cast(buffer), \ (const half*)attr_mask, \ batch_size, \ - head_num, \ - seq_len_1, \ - seq_len_2, \ - (const half)scalar); \ + seq_len); \ } else { \ softmax_kernel_v4_half2<__half, ITEMS_PER_THREAD> \ <<>>(reinterpret_cast(buffer), \ (const half*)attr_mask, \ batch_size, \ - head_num, \ - seq_len_1, \ - seq_len_2, \ - (const half)scalar); \ + seq_len); \ } \ } else { \ - softmax_kernel_v4 \ - <<>>(buffer, \ - buffer_src, \ - attr_mask, \ - batch_size, \ - head_num, \ - seq_len_1, \ - seq_len_2, \ - scalar); \ + softmax_kernel_v4<<>>( \ + buffer, buffer_src, attr_mask, batch_size, seq_len); \ } template @@ -191,19 +178,16 @@ void invokeMaskedSoftMax(T* buffer, const T* buffer_src, const T* attr_mask, const int batch_size, - const int seq_len_1, - const int seq_len_2, - const int head_num, - const T scalar, + const int seq_len, cudaStream_t stream) { - // NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2) - dim3 grid(seq_len_1, batch_size, head_num); - if (batch_size * head_num > 360) { - grid.x = ceil(static_cast(seq_len_1) / 32.0f); + // NOTE: attention scores shape (batch_size, seq_len) + dim3 grid(1, batch_size, 1); + if (batch_size > 360) { + grid.x = ceil(static_cast(1) / 32.0f); } - bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len_2 % 2 == 0; - dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32); + bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len % 2 == 0; + dim3 block((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); if (block.x > 2048 && block.x <= 4096) { SOFTMAX_KERNEL(4) @@ -766,26 +750,19 @@ void MoeKernel(const Context& ctx, k, batch_size, ctx.stream()); - T scalar = (T)1.0f; if (IS_FP16) { invokeMaskedSoftMax<__half>(reinterpret_cast<__half*>(gating_output), reinterpret_cast(gating_output), reinterpret_cast(attr_mask), /*batch_size=*/num_rows, - /*seq_len_1=*/1, - /*seq_len_2=*/num_experts, - /*head_num=*/1, - *reinterpret_cast(&scalar), + /*seq_len=*/num_experts, ctx.stream()); } else { invokeMaskedSoftMax(reinterpret_cast(gating_output), reinterpret_cast(gating_output), reinterpret_cast(attr_mask), /*batch_size=*/num_rows, - /*seq_len_1=*/1, - /*seq_len_2=*/num_experts, - /*head_num=*/1, - *reinterpret_cast(&scalar), + /*seq_len=*/num_experts, ctx.stream()); } InvokeTransposeAxis01( diff --git a/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h b/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h index 73fc242c66..7d63e74fb9 100644 --- a/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h +++ b/paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h @@ -26,87 +26,6 @@ static inline size_t AlignTo16(const size_t& input) { return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); } -/* -WarpReduce multi values. -TODO(zhengzekang): Add blocksize templates to reduce shared memory usage. 
-*/ -template -__inline__ __device__ T warpReduceSumV2(T* val) { -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceSumV2(T* val) { - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduceSumV2(val); - - if (lane == 0) { -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[i][wid] = val[i]; - } - } - - __syncthreads(); - - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[i][lane] : (T)(0.0f); - } - warpReduceSumV2(val); - return (T)0.0f; -} - -template -__inline__ __device__ T warpReduceMaxV2(T* val) { -#pragma unroll - for (int i = 0; i < NUM; i++) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); - } - return (T)(0.0f); -} - -template -__inline__ __device__ T blockReduceMaxV2(T* val) { - static __shared__ T shared[32][NUM]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - warpReduceMaxV2(val); // get maxx in each warp - - if (lane == 0) { // record in-warp maxx by warp Idx -#pragma unroll - for (int i = 0; i < NUM; i++) { - shared[wid][i] = val[i]; - } - } - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[lane][i] : (T)-1e20f; - } - warpReduceMaxV2(val); - - return (T)0.0f; -} - class CubKeyValueSorter { public: CubKeyValueSorter(); @@ -311,65 +230,57 @@ __global__ void initialize_expert_choice_route_kernel( template __global__ void softmax_kernel_v4( T* qk_buf_, - const T* qk_buf_src, // shape [batch_size, head_num, seq_len_1, seq_len_2] - const T* attr_mask, // shape [batch_size, seq_len_1, seq_len_2] + const T* qk_buf_src, // shape [batch_size, seq_len] + const T* attr_mask, // shape [batch_size, seq_len] const int batch_size, - const int head_num, - const int seq_len_1, - const int seq_len_2, - const T scalar) { + const int seq_len) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) { - float data[ITEMS_PER_THREAD]; - int qk_offset; - __shared__ float s_mean, s_max; - float local_max = -1e20f; - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { - qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * - seq_len_2 + - blockDim.x * i + threadIdx.x; - int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 + - blockDim.x * i + threadIdx.x; - - float qk = static_cast(qk_buf_src[qk_offset]); - float mask_val = static_cast(__ldg(&attr_mask[mask_offset])); - - mask_val = (1.0f - mask_val) * -10000.0f; - - data[i] = qk * static_cast(scalar) + mask_val; - local_max = fmax(local_max, data[i]); - } + float data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x; + + float qk = static_cast(qk_buf_src[qk_offset]); + float mask_val = 
static_cast(__ldg(&attr_mask[mask_offset])); + + mask_val = (1.0f - mask_val) * -10000.0f; + + data[i] = qk + mask_val; + local_max = fmax(local_max, data[i]); + } - float max_val = - blockDim.x <= 32 - ? phi::funcs::warpReduceMax(local_max, 0xFFFFFFFF) - : phi::funcs::blockReduceMax(local_max, 0xffffffff); - if (threadIdx.x == 0) { - s_max = max_val; - } - __syncthreads(); + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); - float local_sum = 0; - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { - data[i] = __expf(data[i] - s_max); - local_sum += data[i]; - } - float sum_val = - blockDim.x <= 32 - ? phi::funcs::warpReduceSum(local_sum, 0xffffffff) - : phi::funcs::blockReduceSum(local_sum, 0xffffffff); - if (threadIdx.x == 0) { - s_mean = sum_val + 1e-6f; - s_mean = __fdividef(1.0f, s_mean); - } - __syncthreads(); + float local_sum = 0; + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + data[i] = __expf(data[i] - s_max); + local_sum += data[i]; + } + float sum_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); - for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) { - qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * - seq_len_2 + - blockDim.x * i + threadIdx.x; - qk_buf_[qk_offset] = (T)(data[i] * s_mean); - } + for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) { + qk_offset = + ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x; + qk_buf_[qk_offset] = (T)(data[i] * s_mean); } #endif } @@ -378,77 +289,69 @@ template __global__ void softmax_kernel_v4_half2(T* qk_buf_, const T* attr_mask, const int batch_size, - const int head_num, - const int seq_len_1, - const int seq_len_2, - const T scalar) { + const int seq_len) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 using T2 = half2; T2* qk_buf_half2 = reinterpret_cast(qk_buf_); const T2* attr_mask_half2 = (const T2*)attr_mask; - for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) { - T2 data[ITEMS_PER_THREAD]; - int qk_offset; - __shared__ float s_mean, s_max; - float local_max = -1e20f; - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; - i++) { - qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * - (seq_len_2 / 2) + - blockDim.x * i + threadIdx.x; - int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) + - blockDim.x * i + threadIdx.x; - - T2 qk = qk_buf_half2[qk_offset]; - T2 mask_val = __ldg(&attr_mask_half2[mask_offset]); - mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), - __float2half2_rn(-10000.0f)); - - data[i] = __hadd2(__hmul2(qk, __half2half2(scalar)), mask_val); - - local_max = fmax( - local_max, - fmax(static_cast(data[i].x), static_cast(data[i].y))); - } + T2 data[ITEMS_PER_THREAD]; + int qk_offset; + __shared__ float s_mean, s_max; + float local_max = -1e20f; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + int mask_offset = blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x; + + T2 qk = qk_buf_half2[qk_offset]; + T2 
mask_val = __ldg(&attr_mask_half2[mask_offset]); + mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val), + __float2half2_rn(-10000.0f)); + + data[i] = __hadd2(qk, mask_val); + + local_max = fmax( + local_max, + fmax(static_cast(data[i].x), static_cast(data[i].y))); + } - float max_val = - blockDim.x <= 32 - ? phi::funcs::warpReduceMax(local_max, 0xFFFFFFFF) - : phi::funcs::blockReduceMax(local_max, 0xFFFFFFFF); - if (threadIdx.x == 0) { - s_max = max_val; - } - __syncthreads(); - - float local_sum = 0; - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; - i++) { - data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); - local_sum += static_cast(data[i].x + data[i].y); - } + float max_val = + blockDim.x <= 32 + ? phi::funcs::WarpReduceMax(local_max, 0xFFFFFFFF) + : phi::funcs::BlockReduceMax(local_max, 0xFFFFFFFF); + if (threadIdx.x == 0) { + s_max = max_val; + } + __syncthreads(); - float sum_val = - blockDim.x <= 32 - ? phi::funcs::warpReduceSum(local_sum, 0xFFFFFFFF) - : phi::funcs::blockReduceSum(local_sum, 0xFFFFFFFF); + float local_sum = 0; + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max))); + local_sum += static_cast(data[i].x + data[i].y); + } - if (threadIdx.x == 0) { - s_mean = sum_val + 1e-6f; - s_mean = __fdividef(1.0f, s_mean); - } - __syncthreads(); - - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; - i++) { - qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) * - (seq_len_2 / 2) + - blockDim.x * i + threadIdx.x; - qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); - } + float sum_val = + blockDim.x <= 32 + ? 
phi::funcs::WarpReduceSum(local_sum, 0xFFFFFFFF) + : phi::funcs::BlockReduceSum(local_sum, 0xFFFFFFFF); + + if (threadIdx.x == 0) { + s_mean = sum_val + 1e-6f; + s_mean = __fdividef(1.0f, s_mean); + } + __syncthreads(); + + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i + + threadIdx.x; + qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean)); } #endif } @@ -457,131 +360,123 @@ template __global__ void softmax_kernel_v5_half2(T* qk_buf_, const T* attr_mask, const int batch_size, - const int head_num, - const int seq_len_1, - const int seq_len_2, - const T scalar) { + const int seq_len) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 using T2 = half2; T2* qk_buf_half2 = reinterpret_cast(qk_buf_); const T2* attr_mask_half2 = (const T2*)attr_mask; - for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x * NUM) { - T2 data[NUM][ITEMS_PER_THREAD]; + T2 data[NUM][ITEMS_PER_THREAD]; - int qk_offset[NUM]; + int qk_offset[NUM]; - __shared__ float s_sum[NUM], s_max[NUM]; - float local_max[NUM]; + __shared__ float s_sum[NUM], s_max[NUM]; + float local_max[NUM]; #pragma unroll - for (int j = 0; j < NUM; j++) { - local_max[j] = -1e20f; - } + for (int j = 0; j < NUM; j++) { + local_max[j] = -1e20f; + } - const int MAX_NUM = - min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM); - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; - i++) { - int mask_offset[NUM]; + const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM); + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { + int mask_offset[NUM]; #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + - seq_id + j * gridDim.x) * - (seq_len_2 / 2) + + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) + blockDim.x * i + threadIdx.x; - mask_offset[j] = (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) * - (seq_len_2 / 2) + - blockDim.x * i + threadIdx.x; - } + } - T2 mask_val[NUM]; + T2 mask_val[NUM]; #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); - } + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]); + } - T2 qk[NUM]; -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk[j] = qk_buf_half2[qk_offset[j]]; - } + T2 qk[NUM]; #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), - __float2half2_rn(-10000.0f)); - } -#pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - data[j][i] = __hadd2(__hmul2(qk[j], __half2half2(scalar)), mask_val[j]); - local_max[j] = fmax(local_max[j], - fmax(static_cast(data[j][i].x), - static_cast(data[j][i].y))); - } + for (int j = 0; j < MAX_NUM; j++) { + qk[j] = qk_buf_half2[qk_offset[j]]; } - if (blockDim.x <= 32) { - warpReduceMaxV2(local_max); - } else { - blockReduceMaxV2(local_max); +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]), + __float2half2_rn(-10000.0f)); } - - if (threadIdx.x == 0) { #pragma unroll - for (int j = 0; j < NUM; j++) { - s_max[j] = local_max[j]; - } + for (int j = 0; j < MAX_NUM; 
j++) { + data[j][i] = __hadd2(qk[j], mask_val[j]); + local_max[j] = fmax(local_max[j], + fmax(static_cast(data[j][i].x), + static_cast(data[j][i].y))); } - __syncthreads(); - float local_sum[NUM]; + } + if (blockDim.x <= 32) { + phi::funcs::WarpReduceMaxV2(local_max); + } else { + phi::funcs::BlockReduceMaxV2(local_max); + } + + if (threadIdx.x == 0) { #pragma unroll for (int j = 0; j < NUM; j++) { - local_sum[j] = {0.f}; + s_max[j] = local_max[j]; } - - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; - i++) { + } + __syncthreads(); + float local_sum[NUM]; #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); - } + for (int j = 0; j < NUM; j++) { + local_sum[j] = {0.f}; + } + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - local_sum[j] += static_cast(data[j][i].x + data[j][i].y); - } + for (int j = 0; j < MAX_NUM; j++) { + data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j]))); } - if (blockDim.x <= 32) { - warpReduceSumV2(local_sum); - } else { - blockReduceSumV2(local_sum); +#pragma unroll + for (int j = 0; j < MAX_NUM; j++) { + local_sum[j] += static_cast(data[j][i].x + data[j][i].y); } + } - if (threadIdx.x == 0) { + if (blockDim.x <= 32) { + phi::funcs::WarpReduceSumV2(local_sum); + + } else { + phi::funcs::BlockReduceSumV2(local_sum); + } + + if (threadIdx.x == 0) { #pragma unroll - for (int j = 0; j < NUM; j++) { - s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); - } + for (int j = 0; j < NUM; j++) { + s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); } - __syncthreads(); + } + __syncthreads(); - for (int i = 0; - blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD; - i++) { + for (int i = 0; + blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; + i++) { #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + - seq_id + j * gridDim.x) * - (seq_len_2 / 2) + - blockDim.x * i + threadIdx.x; - } + for (int j = 0; j < MAX_NUM; j++) { + qk_offset[j] = + ((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) + + blockDim.x * i + threadIdx.x; + } #pragma unroll - for (int j = 0; j < MAX_NUM; j++) { - qk_buf_half2[qk_offset[j]] = - __hmul2(data[j][i], __float2half2_rn(s_sum[j])); - } + for (int j = 0; j < MAX_NUM; j++) { + qk_buf_half2[qk_offset[j]] = + __hmul2(data[j][i], __float2half2_rn(s_sum[j])); } } #endif diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index e89a98d49e..5040be8eaa 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -62,7 +62,7 @@ __global__ void ReduceSumWithSubtract( } __syncthreads(); - sum_val = phi::funcs::blockReduceSum(sum_val, FULL_MASK); + sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); if (threadIdx.x == 0) { out[blockIdx.x] = sum_val; } @@ -80,7 +80,7 @@ __global__ void ReduceMaxWithSubtract(const T* x, } __syncthreads(); - max_val = phi::funcs::blockReduceMax(max_val, FULL_MASK); + max_val = phi::funcs::BlockReduceMax(max_val, FULL_MASK); if (threadIdx.x == 0) { out[blockIdx.x] = max_val; } @@ -98,7 +98,7 @@ __global__ void ReduceMinWithSubtract(const T* x, } __syncthreads(); - min_val = phi::funcs::blockReduceMin(min_val, FULL_MASK); + min_val = phi::funcs::BlockReduceMin(min_val, FULL_MASK); if (threadIdx.x == 0) { 
out[blockIdx.x] = min_val; } diff --git a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu index 9a9d2e80ab..57e0d7f3a1 100644 --- a/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/interpolate_grad_kernel.cu @@ -211,7 +211,7 @@ __inline__ __device__ T PartialBlockMin(T val, if (threadIdx.x < threshold) { shared_last_idx = (threshold >> 5) - 1; - val = phi::funcs::warpReduceMin(val, mask); + val = phi::funcs::WarpReduceMin(val, mask); if (lane == 0) { shared[wid] = val; } @@ -226,7 +226,7 @@ __inline__ __device__ T PartialBlockMin(T val, if (threadIdx.x < threshold) { val = (lane <= shared_last_idx) ? shared[lane] : std::numeric_limits::max(); - val = phi::funcs::warpReduceMin(val, mask); + val = phi::funcs::WarpReduceMin(val, mask); shared_last_val = val; } __syncthreads(); @@ -292,13 +292,13 @@ __global__ void KeBilinearInterpBwShareMemory(T* in, s_data[1][threadIdx.x] = static_cast(0); int remain = nthreads - (tid & (-blockDim.x)); int in_top_max_index = - phi::funcs::blockReduceMax(top_right_index, FINAL_MASK); + phi::funcs::BlockReduceMax(top_right_index, FINAL_MASK); int in_bot_max_index = - phi::funcs::blockReduceMax(bot_right_index, FINAL_MASK); + phi::funcs::BlockReduceMax(bot_right_index, FINAL_MASK); if (remain > blockDim.x) { - in_top_min_index = phi::funcs::blockReduceMin(input_index, FINAL_MASK); - in_bot_min_index = phi::funcs::blockReduceMin(bot_left_index, FINAL_MASK); + in_top_min_index = phi::funcs::BlockReduceMin(input_index, FINAL_MASK); + in_bot_min_index = phi::funcs::BlockReduceMin(bot_left_index, FINAL_MASK); } else { in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK); in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK); diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu index 4c83203ed0..64c7a5822c 100644 --- a/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu @@ -47,7 +47,7 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows, for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { mul += out_values[row_first + idx] * dout_values[row_first + idx]; } - T mul_sum = phi::funcs::warpReduceSum(mul, 0xFFFFFFFF); + T mul_sum = phi::funcs::WarpReduceSum(mul, 0xFFFFFFFF); for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) * diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu index 04d143fdb3..cd8013b4ee 100644 --- a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu @@ -72,7 +72,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows, out_values[row_first + idx] = -std::numeric_limits::infinity(); } } - T row_max_val = phi::funcs::warpReduceMax(max_val, 0xFFFFFFFF); + T row_max_val = phi::funcs::WarpReduceMax(max_val, 0xFFFFFFFF); T exp_sum = 0; for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { @@ -81,7 +81,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows, exp_sum += exp; out_values[row_first + idx] = exp; } - T row_exp_sum = phi::funcs::warpReduceSum(exp_sum, 0xFFFFFFFF); + T row_exp_sum = phi::funcs::WarpReduceSum(exp_sum, 0xFFFFFFFF); for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) { 
out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum; diff --git a/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu index 5a66786ebb..6b040b6992 100644 --- a/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu @@ -53,7 +53,7 @@ __global__ void SoftmaxGradGpuKernel(const IntT* out_crows, mul_result += out_values[row_first + idx] * dout_values[row_first + idx]; } - T sum = phi::funcs::warpReduceSum(mul_result, 0xFFFFFFFF); + T sum = phi::funcs::WarpReduceSum(mul_result, 0xFFFFFFFF); for (int i = 0; i < kIteration; ++i) { int idx = non_zero_idx + i * warpSize; diff --git a/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu b/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu index ef6b6d91e5..ace65355b6 100644 --- a/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/softmax_kernel.cu @@ -57,7 +57,7 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows, max_val = val; } } - T row_max_val = phi::funcs::warpReduceMax(max_val, 0xFFFFFFFF); + T row_max_val = phi::funcs::WarpReduceMax(max_val, 0xFFFFFFFF); T exp_sum = 0; for (int i = 0; i < kIteration; ++i) { @@ -69,7 +69,7 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows, exp_sum += exp; out_values[row_first + idx] = exp; } - T row_exp_sum = phi::funcs::warpReduceSum(exp_sum, 0xFFFFFFFF); + T row_exp_sum = phi::funcs::WarpReduceSum(exp_sum, 0xFFFFFFFF); for (int i = 0; i < kIteration; ++i) { int idx = non_zero_idx + i * warpSize; -- GitLab
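
Note on the renamed reduction helpers: phi::funcs::WarpReduceSum / WarpReduceMax and their Block* counterparts used throughout this patch follow the usual two-level CUDA reduction pattern — a shuffle-based butterfly reduction inside each 32-thread warp, then a second pass over per-warp partials staged in shared memory. The sketch below is a minimal, self-contained illustration of that pattern only; it is not code from this patch. The kernel name RowSumKernel, the host driver, and the local WarpReduceSum/BlockReduceSum definitions are assumptions made for the example and merely mirror the shape of the helpers added to paddle/phi/kernels/funcs/math_cuda_utils.h.

// row_sum_sketch.cu — illustrative only, not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

#define FINAL_MASK 0xffffffffu

template <typename T>
__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
  // Butterfly reduction within one warp: after the loop every lane holds the warp sum.
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(lane_mask, val, offset, 32);
  return val;
}

template <typename T>
__inline__ __device__ T BlockReduceSum(T val, unsigned mask) {
  static __shared__ T shared[32];  // one partial per warp (max 32 warps per block)
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = WarpReduceSum(val, mask);  // reduce inside each warp
  __syncthreads();
  if (lane == 0) shared[wid] = val;  // warp leaders publish their partials
  __syncthreads();

  // Every warp re-reduces the per-warp partials, so all threads return the block total.
  int block_span = (blockDim.x + 31) >> 5;
  val = (lane < block_span) ? shared[lane] : static_cast<T>(0);
  val = WarpReduceSum(val, mask);
  return val;
}

__global__ void RowSumKernel(const float* x, float* row_sum, int cols) {
  // One block per row; each thread strides over the columns of its row.
  const float* row = x + blockIdx.x * cols;
  float partial = 0.f;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) partial += row[i];
  float total = BlockReduceSum(partial, FINAL_MASK);
  if (threadIdx.x == 0) row_sum[blockIdx.x] = total;
}

int main() {
  const int rows = 4, cols = 1000;
  float *x = nullptr, *row_sum = nullptr;
  cudaMallocManaged(&x, rows * cols * sizeof(float));
  cudaMallocManaged(&row_sum, rows * sizeof(float));
  for (int i = 0; i < rows * cols; ++i) x[i] = 1.f;

  RowSumKernel<<<rows, 256>>>(x, row_sum, cols);
  cudaDeviceSynchronize();
  for (int r = 0; r < rows; ++r) printf("row %d sum = %f\n", r, row_sum[r]);

  cudaFree(x);
  cudaFree(row_sum);
  return 0;
}

Built with nvcc, this should print 1000 for every row. The same warp-then-block shape is what the renamed call sites in the diff (the softmax, LARS, dist, and interpolate kernels) rely on; the *V2 variants moved into math_cuda_utils.h apply the identical idea to NUM values at once so a kernel can reduce, for example, a running max and a running sum in a single pass.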