Unverified commit 2f0f10b3, authored by zhaoyuchen2018, committed by GitHub

[cherry-pick] Fix multihead op bug. (#20783) (#21438)

The op should handle k = 1024.
Fix the error when seq_len < warp size.

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent 873b32de
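In short, the cherry-pick does two things: the check on k = head_number * size_per_head now accepts the CUDA block-size limit of 1024 (PADDLE_ENFORCE_LE instead of PADDLE_ENFORCE_LT), and the shuffle-based reductions take an explicit lane mask so that a block smaller than one warp (seq_len < 32) only synchronizes the lanes that actually exist; calling __shfl_xor_sync with the full-warp mask in that case names lanes that were never launched. A minimal host-side sketch of the mask selection, assuming FINAL_MASK is the full-warp mask 0xffffffffu and using a hypothetical helper name (active_lane_mask):

#include <cstdint>

// Assumed to match the FINAL_MASK macro in multihead_matmul_op.cu.
constexpr unsigned kFullWarpMask = 0xffffffffu;

// Hypothetical helper mirroring the mask chosen in MatMulWithHeadQK below:
// for block = 20 (a sequence shorter than a warp) it yields 0x000fffff, so
// __shfl_xor_sync never waits on lanes 20..31, which do not exist.
inline unsigned active_lane_mask(int block) {
  return block < 32 ? ((1u << block) - 1) : kFullWarpMask;
}

The same mask is threaded from the launch site through softmax_kernel_with_eltadd into blockReduceMax and blockReduceSum, as the hunks below show.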
@@ -134,7 +134,7 @@ MultiHeadMatMul Operator.
 This op is used for optimize multi head calculation in ernie model.
 Not suggest to use in other case except has same structure as ernie.
-Example of matrix multiplication with head_number of H
+Example of matrix multiplication with head_number of B
 - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
 Both the input `Q` and `K` can carry the LoD (Level of Details) information,
@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
-  val = warpReduceSum<T>(val);
+  val = warpReduceSum<T>(val, mask);
   if (lane == 0) shared[wid] = val;
   __syncthreads();
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val, mask);
   return val;
 }
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
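A note on the block_span change in blockReduceSum (the same fix is applied to blockReduceMax in the next hunk): the second warp-level pass only works if at least one thread reads the per-warp partials back out of shared memory. With the old guard threadIdx.x < (blockDim.x >> 5), a block of fewer than 32 threads makes blockDim.x >> 5 equal to 0, so no thread reloads shared[0] and the reduction result is dropped. Rounding the warp count up keeps the first warp active for small blocks. A tiny host-side sketch of the arithmetic at the boundary cases (warpSize is 32 on current GPUs):

#include <cstdio>
#include <initializer_list>

int main() {
  const int kWarpSize = 32;
  for (int block_dim : {20, 32, 33, 128}) {
    int old_span = block_dim >> 5;                    // truncates: 0 when block_dim = 20
    int new_span = (block_dim + kWarpSize - 1) >> 5;  // rounds up:  1 when block_dim = 20
    printf("blockDim.x = %3d  old span = %d  new span = %d\n", block_dim,
           old_span, new_span);
  }
  return 0;
}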
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);
   if (lane == 0) shared[wid] = val;
   __syncthreads();
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);
   return val;
 }
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                       bias_qk_[threadIdx.x + bias_offset]))
                 : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+  float max_val = blockReduceMax<float>(tmp, mask);
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();
   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);
   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
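For context on what this kernel computes: each block handles one row of QK^T plus its bias and applies a numerically stable softmax, subtracting the row maximum before exponentiation and adding 1e-6f to the sum so the division is safe even for a fully masked row. A scalar CPU sketch of the same sequence of steps (illustrative only, not the kernel itself; assumes a non-empty row):

#include <algorithm>
#include <cmath>
#include <vector>

// out[i] = exp(x[i] - max(x)) / (sum_j exp(x[j] - max(x)) + 1e-6f)
std::vector<float> softmax_row(const std::vector<float> &x) {
  float max_val = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - max_val);
    sum += out[i];
  }
  sum += 1e-6f;  // mirrors the s_sum = sum_val + 1e-6f guard in the kernel
  for (float &v : out) v /= sum;
  return out;
}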
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;
+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
 template <typename T>
@@ -331,7 +339,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
   auto stream = dev_ctx.stream();
   int grid = m;
-  PADDLE_ENFORCE_LT(k, 1024,
+  PADDLE_ENFORCE_LE(k, 1024,
                     "Input head_number * size_per_head should <= 1024");
   int block = k <= 1024 ? k : 1024;
   add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
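The LT-to-LE change matters because k is used directly as the launch block size, and 1024 threads per block is exactly the hardware limit on current GPUs, so k = 1024 is a legal configuration that the old check rejected. A hedged sketch of the same bound expressed against the device limit instead of a hard-coded constant (cudaGetDeviceProperties is the standard CUDA runtime query):

#include <cstdio>
#include <cuda_runtime.h>

// Sketch only: check whether a block of k threads can be launched on device 0.
// prop.maxThreadsPerBlock is 1024 on current architectures, so k == 1024 must
// pass, which is why the enforce becomes "<=" (LE) rather than "<" (LT).
bool block_size_is_valid(int k) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  return k <= prop.maxThreadsPerBlock;
}

int main() {
  printf("k = 1024 launchable: %d\n", static_cast<int>(block_size_is_valid(1024)));
  return 0;
}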