From 42f35841a890f61781d0cdf26f709583ca7db4b3 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Fri, 18 Nov 2022 11:17:26 +0800 Subject: [PATCH] fix: support huge length of attention (#48053) --- .../operators/math/bert_encoder_functor.cu | 65 +++++++++++++------ 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/math/bert_encoder_functor.cu b/paddle/fluid/operators/math/bert_encoder_functor.cu index 5b11eee61a..a97ab99dc2 100644 --- a/paddle/fluid/operators/math/bert_encoder_functor.cu +++ b/paddle/fluid/operators/math/bert_encoder_functor.cu @@ -783,6 +783,19 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_, } } +#define SOFTMAX_KERNEL_WITH_MASK(REPEAT_THREAD) \ + do { \ + block.x /= REPEAT_THREAD; \ + grid.x /= 4; \ + constexpr int NUM = 4; \ + softmax_kernel_with_mask \ + <<>>(reinterpret_cast(qk_buf_), \ + (const half *)bias_qk, \ + batch_size, \ + head_num, \ + seq_len); \ + } while (0) + template inline void MatMulWithHeadQK(const phi::GPUContext &context, int head_num, @@ -843,22 +856,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, "QK_bias is mask can't be supported on rocm or " "cuda_arch<700")); #else - constexpr int ITEMS_PER_THREAD = 1; - bool is_half2 = true; - dim3 grid(seq_len, batch_size, head_num); dim3 block((seq_len / 2 + 31) / 32 * 32); - block.x /= ITEMS_PER_THREAD; - assert(block.x <= 1024); - assert(grid.x % 4 == 0); - grid.x /= 4; - constexpr int NUM = 4; - softmax_kernel_with_mask - <<>>(reinterpret_cast(qk_buf_), - (const half *)bias_qk, - batch_size, - head_num, - seq_len); + SOFTMAX_KERNEL_WITH_MASK(1); #endif } else { SoftmaxKernelWithEltadd2<__half2><<>>( @@ -888,13 +888,36 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, seq_len / 2, FINAL_MASK); } else { - SoftmaxKernelWithEltaddForLarge2<__half2><<>>( - reinterpret_cast<__half2 *>(qk_buf_), - reinterpret_cast(bias_qk), - batch_size, - head_num, - seq_len / 2, - FINAL_MASK); + if (bias_is_mask) { +#if 
defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) + PADDLE_ENFORCE_EQ(bias_is_mask, + false, + platform::errors::InvalidArgument( + "QK_bias is mask can't be supported on rocm or " + "cuda_arch<700")); +#else + dim3 grid(seq_len, batch_size, head_num); + dim3 block((seq_len / 2 + 31) / 32 * 32); + if (block.x > 0 && block.x <= 1024) { + SOFTMAX_KERNEL_WITH_MASK(1); + } else if (block.x <= 2048) { + SOFTMAX_KERNEL_WITH_MASK(2); + } else if (block.x <= 4096) { + SOFTMAX_KERNEL_WITH_MASK(4); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Cannot support the length of attention > 8192.")); + } +#endif + } else { + SoftmaxKernelWithEltaddForLarge2<__half2><<>>( + reinterpret_cast<__half2 *>(qk_buf_), + reinterpret_cast(bias_qk), + batch_size, + head_num, + seq_len / 2, + FINAL_MASK); + } } } else { SoftmaxKernelWithEltaddForLarge<<>>( -- GitLab