未验证 提交 42f35841 编写于 作者: F feng_shuai 提交者: GitHub

fix: support huge length of attention (#48053)

上级 85598e31
...@@ -783,6 +783,19 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_, ...@@ -783,6 +783,19 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
} }
} }
/* Launches softmax_kernel_with_mask for fp16 attention scores, where each
 * thread processes REPEAT_THREAD elements so that seq_len up to
 * 1024 * 2 * REPEAT_THREAD can be covered with block.x <= 1024.
 *
 * Expects the following names in the enclosing scope:
 *   grid, block  - dim3 launch config (grid = {seq_len, batch_size, head_num},
 *                  block.x = ceil(seq_len / 2, 32)); both are MODIFIED in place,
 *                  so the macro must not be invoked twice on the same config.
 *   qk_buf_      - QK product buffer, reinterpreted as half* (in/out).
 *   bias_qk      - attention mask, cast to const half*.
 *   batch_size, head_num, seq_len, stream - launch parameters.
 *
 * NUM = 4: the kernel handles 4 grid.x slices per block, hence grid.x /= 4.
 * The asserts restore the guards from the pre-macro code: block.x must fit a
 * CUDA block (<= 1024) after division, and grid.x (== seq_len) must be
 * divisible by 4 for the /= 4 folding to be exact.
 * Wrapped in do { } while (0) so the macro behaves as a single statement. */
#define SOFTMAX_KERNEL_WITH_MASK(REPEAT_THREAD)                          \
  do {                                                                   \
    block.x /= REPEAT_THREAD;                                            \
    assert(block.x <= 1024);                                             \
    assert(grid.x % 4 == 0);                                             \
    grid.x /= 4;                                                         \
    constexpr int NUM = 4;                                               \
    softmax_kernel_with_mask<half, REPEAT_THREAD, NUM>                   \
        <<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_),  \
                                     (const half *)bias_qk,              \
                                     batch_size,                         \
                                     head_num,                           \
                                     seq_len);                           \
  } while (0)
template <typename T> template <typename T>
inline void MatMulWithHeadQK(const phi::GPUContext &context, inline void MatMulWithHeadQK(const phi::GPUContext &context,
int head_num, int head_num,
...@@ -843,22 +856,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, ...@@ -843,22 +856,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
"QK_bias is mask can't be supported on rocm or " "QK_bias is mask can't be supported on rocm or "
"cuda_arch<700")); "cuda_arch<700"));
#else #else
constexpr int ITEMS_PER_THREAD = 1;
bool is_half2 = true;
dim3 grid(seq_len, batch_size, head_num); dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + 31) / 32 * 32); dim3 block((seq_len / 2 + 31) / 32 * 32);
block.x /= ITEMS_PER_THREAD; SOFTMAX_KERNEL_WITH_MASK(1);
assert(block.x <= 1024);
assert(grid.x % 4 == 0);
grid.x /= 4;
constexpr int NUM = 4;
softmax_kernel_with_mask<half, ITEMS_PER_THREAD, NUM>
<<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_),
(const half *)bias_qk,
batch_size,
head_num,
seq_len);
#endif #endif
} else { } else {
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>( SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
...@@ -888,13 +888,36 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context, ...@@ -888,13 +888,36 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
seq_len / 2, seq_len / 2,
FINAL_MASK); FINAL_MASK);
} else { } else {
SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>( if (bias_is_mask) {
reinterpret_cast<__half2 *>(qk_buf_), #if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
reinterpret_cast<const __half2 *>(bias_qk), PADDLE_ENFORCE_EQ(bias_is_mask,
batch_size, false,
head_num, platform::errors::InvalidArgument(
seq_len / 2, "QK_bias is mask can't be supported on rocm or "
FINAL_MASK); "cuda_arch<700"));
#else
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + 31) / 32 * 32);
if (block.x > 0 && block.x <= 1024) {
SOFTMAX_KERNEL_WITH_MASK(1);
} else if (block.x <= 2048) {
SOFTMAX_KERNEL_WITH_MASK(2);
} else if (block.x <= 4096) {
SOFTMAX_KERNEL_WITH_MASK(4);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot support the length of attention > 8192."));
}
#endif
} else {
SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
} }
} else { } else {
SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>( SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册