diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index b0b347123047ac8f9583f5188d52e4177392e27e..8728fd9d21db6a13ee98e46ea331221b88a6d813 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceSum(val);
+  val = warpReduceSum(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum(val, mask);
 
   return val;
 }
 
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);
 
   return val;
 }
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                        bias_qk_[threadIdx.x + bias_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax(tmp);
+
+  float max_val = blockReduceMax(tmp, mask);
+
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();
 
   float qk_tmp = threadIdx.x < seq_len ?
                      __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum(qk_tmp);
+  float sum_val = blockReduceSum(qk_tmp, mask);
 
   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
 
   int grid = m;
   int block = k;
+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
 
 template <typename T>
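
For context on the new `mask` argument: the launch site builds a bitmask with one bit per launched thread, so `__shfl_xor_sync` is never asked to synchronize with lanes that were not launched when `block` (i.e. `seq_len`) is smaller than a warp. Below is a minimal standalone host-side sketch of that mask expression; `ShuffleMaskForBlock` is a hypothetical helper introduced only for illustration and is not part of the patch.

```cpp
// Standalone sketch (not from the patch): reproduces the mask expression added
// in MatMulWithHeadQK so its effect for sub-warp block sizes is easy to see.
#include <cstdio>

#define FINAL_MASK 0xffffffffu

// Hypothetical helper: one mask bit per launched thread.
//   block < 32  -> only lanes [0, block) are named in the sync mask
//   block >= 32 -> every warp is full, so all 32 lanes participate
static unsigned ShuffleMaskForBlock(int block) {
  return block < 32 ? ((1u << block) - 1) : FINAL_MASK;
}

int main() {
  const int block_sizes[] = {8, 24, 32, 128};  // e.g. block = seq_len
  for (int block : block_sizes) {
    printf("block = %3d -> mask = 0x%08x\n", block, ShuffleMaskForBlock(block));
  }
  return 0;
}
```

Threading this mask through blockReduceSum/blockReduceMax keeps the member mask passed to the shuffle intrinsics consistent with the number of threads actually launched, which is what the per-call `mask`/`lane_mask` parameters in the hunks above provide.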