diff --git a/paddle/fluid/operators/multihead_matmul_op.cc b/paddle/fluid/operators/multihead_matmul_op.cc
index b612be02b4f50ff1d50c6d8a3e1e0c5c1e9f61c6..fbf372ba6e15aca7b849a8696ac5551dc383ee51 100644
--- a/paddle/fluid/operators/multihead_matmul_op.cc
+++ b/paddle/fluid/operators/multihead_matmul_op.cc
@@ -134,7 +134,7 @@ MultiHeadMatMul Operator.
 This op is used for optimize multi head calculation in ernie model. Not suggest to
 use in other case except has same structure as ernie.
 
-Example of matrix multiplication with head_number of H
+Example of matrix multiplication with head_number of B
 - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
 
 Both the input `Q` and `K` can carry the LoD (Level of Details) information,
diff --git a/paddle/fluid/operators/multihead_matmul_op.cu b/paddle/fluid/operators/multihead_matmul_op.cu
index 6e8aa712fbf00355b83bde5313ba0d04724e2ffb..8728fd9d21db6a13ee98e46ea331221b88a6d813 100644
--- a/paddle/fluid/operators/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/multihead_matmul_op.cu
@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceSum(val);
+  val = warpReduceSum(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum(val, mask);
 
   return val;
 }
 
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);
 
   return val;
 }
 
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                  bias_qk_[threadIdx.x + bias_offset]))
           : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+
+  float max_val = blockReduceMax<float>(tmp, mask);
+
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();
 
   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);
 
   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;
 
+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
 
 template <typename T>
@@ -331,7 +339,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
   auto stream = dev_ctx.stream();
   int grid = m;
 
-  PADDLE_ENFORCE_LT(k, 1024,
+  PADDLE_ENFORCE_LE(k, 1024,
                     "Input head_number * size_per_head should <= 1024");
   int block = k <= 1024 ? k : 1024;
   add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
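
Note on the lane mask (reviewer sketch, not part of the patch): `__shfl_xor_sync` takes a mask naming the lanes that participate in the exchange, and every lane named in it must reach the call. When the launch block is smaller than a warp (e.g. `seq_len < 32` for the softmax kernel), passing `FINAL_MASK` (the full-warp mask, `0xffffffff` in this file) would name lanes that never execute, so the patch derives the mask from the block size on the host and threads it through `blockReduceSum`/`blockReduceMax` down to the shuffles. The standalone sketch below illustrates the same rule; `maskedWarpReduceSum`, `reduce_demo`, and `kFullWarpMask` are illustrative names, and the butterfly width is restricted to the active lane count (a power of two here) so every shuffle partner stays inside the mask, unlike the patched kernels, which always sweep a half warp.

```cuda
// Minimal standalone sketch (not the Paddle kernel): a warp-level sum over the
// first `block` lanes using __shfl_xor_sync with an explicit lane mask, built
// with the same host-side rule as the patch:
//   mask = block < 32 ? ((1u << block) - 1) : FINAL_MASK
#include <cstdio>
#include <cuda_runtime.h>

constexpr unsigned kFullWarpMask = 0xffffffffu;  // plays the role of FINAL_MASK

__device__ float maskedWarpReduceSum(float val, unsigned lane_mask, int width) {
  // Butterfly reduction among the `width` lanes named in lane_mask.
  for (int offset = width >> 1; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(lane_mask, val, offset, warpSize);
  return val;
}

__global__ void reduce_demo(const float *in, float *out, unsigned lane_mask,
                            int width) {
  float v = in[threadIdx.x];
  v = maskedWarpReduceSum(v, lane_mask, width);
  if (threadIdx.x == 0) *out = v;  // every active lane now holds the sum
}

int main() {
  const int block = 8;  // fewer threads than a warp -> partial lane mask
  unsigned mask = block < 32 ? ((1u << block) - 1) : kFullWarpMask;

  float h_in[block], h_out = 0.0f;
  for (int i = 0; i < block; ++i) h_in[i] = static_cast<float>(i + 1);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  reduce_demo<<<1, block>>>(d_in, d_out, mask, block);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.1f (expected 36.0)\n", h_out);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```

With `block = 8` the mask is `0xff` and all eight lanes finish holding the same sum; for `block >= 32` the rule falls back to the full-warp mask, matching what `MatMulWithHeadQK` now passes into `softmax_kernel_with_eltadd`.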