Unverified commit 2f0f10b3 authored by zhaoyuchen2018, committed by GitHub

[cherry-pick] Fix multihead op bug. (#20783) (#21438)

The op should handle k = 1024.
Fix the error when seq_len is smaller than the warp size.

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent 873b32de
@@ -134,7 +134,7 @@ MultiHeadMatMul Operator.
 This op is used for optimize multi head calculation in ernie model.
 Not suggest to use in other case except has same structure as ernie.
-Example of matrix multiplication with head_number of H
+Example of matrix multiplication with head_number of B
 - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
 Both the input `Q` and `K` can carry the LoD (Level of Details) information,
......
@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32

 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
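For context, here is a minimal, self-contained sketch of the xor-shuffle butterfly that warpReduceSum performs over one full warp; full_warp_sum and kFullMask are illustrative names, not part of the op. The commit's change is simply to let callers pass the participation mask instead of hard-coding FINAL_MASK.

// Sketch: butterfly sum over one full warp of 32 lanes via __shfl_xor_sync.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void full_warp_sum(const float *in, float *out) {
  const unsigned kFullMask = 0xffffffffu;  // all 32 lanes participate
  float val = in[threadIdx.x];
  // After log2(32) = 5 xor steps every lane holds the warp-wide sum.
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(kFullMask, val, offset, 32);
  if (threadIdx.x == 0) *out = val;
}

int main() {
  float h_in[32], h_out = 0.0f, *d_in = nullptr, *d_out = nullptr;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected sum: 32
  cudaMalloc(reinterpret_cast<void **>(&d_in), sizeof(h_in));
  cudaMalloc(reinterpret_cast<void **>(&d_out), sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  full_warp_sum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %.1f\n", h_out);  // 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}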
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;

-  val = warpReduceSum<T>(val);
+  val = warpReduceSum<T>(val, mask);

   if (lane == 0) shared[wid] = val;

   __syncthreads();

-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val, mask);

   return val;
 }

 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
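A quick host-side check (a sketch; old_span and new_span are illustrative names) of what the block_span change fixes: with a block of 20 threads the old expression blockDim.x >> 5 evaluates to 0, so no thread re-reads the per-warp partials from shared memory, while the rounded-up form counts the partial warp.

// Sketch: warp counts produced by the old and new guard expressions.
#include <cassert>

int old_span(int block_dim) { return block_dim >> 5; }
int new_span(int block_dim) { return (block_dim + 32 - 1) >> 5; }  // warpSize == 32

int main() {
  assert(old_span(20) == 0 && new_span(20) == 1);  // seq_len < warp size: fixed
  assert(old_span(33) == 1 && new_span(33) == 2);  // trailing partial warp: fixed
  assert(old_span(64) == 2 && new_span(64) == 2);  // full warps: unchanged
  return 0;
}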
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;

-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);

   if (lane == 0) shared[wid] = val;

   __syncthreads();

-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);

   return val;
 }
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                        bias_qk_[threadIdx.x + bias_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+
+  float max_val = blockReduceMax<float>(tmp, mask);
+
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();

   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);

   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
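For reference, a host-side sketch of the per-row math softmax_kernel_with_eltadd computes (softmax_with_eltadd_row is an illustrative name, and the final division by s_sum is assumed from the standard softmax since that step is not shown in this hunk): add the bias, subtract the row max for numerical stability, exponentiate, and normalize by the sum plus 1e-6f.

// Sketch: numerically stable softmax over one row of qk with an element-wise bias.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmax_with_eltadd_row(const std::vector<float> &qk,
                                           const std::vector<float> &bias) {
  std::vector<float> tmp(qk.size()), out(qk.size());
  float max_val = -1e20f;
  for (size_t i = 0; i < qk.size(); ++i) {
    tmp[i] = qk[i] + bias[i];             // element-wise add, as in the kernel
    max_val = std::max(max_val, tmp[i]);  // row max (blockReduceMax on device)
  }
  float sum = 0.0f;
  for (size_t i = 0; i < qk.size(); ++i) {
    out[i] = std::exp(tmp[i] - max_val);  // shifted exponent
    sum += out[i];                        // row sum (blockReduceSum on device)
  }
  sum += 1e-6f;                           // matches s_sum = sum_val + 1e-6f
  for (float &v : out) v /= sum;          // assumed final normalization step
  return out;
}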
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;

+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
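A small host-side sketch of the lane-mask expression added above (lane_mask is an illustrative name): when the block holds fewer than 32 threads, only the lanes that actually exist are named in the mask handed to __shfl_xor_sync; larger blocks keep the full-warp FINAL_MASK.

// Sketch: lane masks produced for a few block sizes.
#include <cassert>

unsigned lane_mask(int block) {
  const unsigned kFinalMask = 0xffffffffu;  // FINAL_MASK in the op
  return block < 32 ? (((unsigned)1 << block) - 1) : kFinalMask;
}

int main() {
  assert(lane_mask(1) == 0x1u);           // a single active lane
  assert(lane_mask(20) == 0x000fffffu);   // seq_len = 20 -> low 20 bits set
  assert(lane_mask(32) == 0xffffffffu);   // exactly one full warp
  assert(lane_mask(128) == 0xffffffffu);  // multi-warp blocks keep the full mask
  return 0;
}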
template <typename T> template <typename T>
...@@ -331,7 +339,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx, ...@@ -331,7 +339,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
int grid = m; int grid = m;
PADDLE_ENFORCE_LT(k, 1024, PADDLE_ENFORCE_LE(k, 1024,
"Input head_number * size_per_head should <= 1024"); "Input head_number * size_per_head should <= 1024");
int block = k <= 1024 ? k : 1024; int block = k <= 1024 ? k : 1024;
add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q, add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
......
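The relaxed check matches the hardware limit: the kernel launches k = head_number * size_per_head threads per block, and CUDA caps a block at 1024 threads on current GPUs, so k equal to 1024 is valid and only larger values must be rejected. A minimal host-side query (a sketch) shows the limit PADDLE_ENFORCE_LE guards.

// Sketch: query the per-block thread limit that motivates the <= 1024 check.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  printf("maxThreadsPerBlock = %d\n", prop.maxThreadsPerBlock);  // typically 1024
  return 0;
}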