Unverified commit 42f35841 authored by feng_shuai, committed by GitHub

fix: support huge length of attention (#48053)

Parent 85598e31
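
Before this change, MatMulWithHeadQK launched softmax_kernel_with_mask with ITEMS_PER_THREAD fixed at 1 and asserted block.x <= 1024, so the masked-softmax path broke once the sequence length pushed block.x past CUDA's 1024-thread-per-block limit (seq_len above 2048). The patch factors the launch into a SOFTMAX_KERNEL_WITH_MASK macro and picks a per-thread repeat factor of 1, 2, or 4 so the launched block stays within the limit, raising the supported attention length to 8192. A minimal host-side sketch of the dispatch arithmetic (a hypothetical standalone helper, not part of the patch):

#include <cstdio>

int main() {
  const int kMaxThreadsPerBlock = 1024;  // CUDA hard limit on threads per block
  const int lengths[] = {512, 2048, 4096, 8192, 8448};
  for (int seq_len : lengths) {
    // Same rounding as the patch: half the sequence length, rounded up to a
    // multiple of the warp size (32).
    int block_x = (seq_len / 2 + 31) / 32 * 32;
    int repeat_thread = 0;  // REPEAT_THREAD picked by the if/else ladder below
    if (block_x <= 1024) {
      repeat_thread = 1;
    } else if (block_x <= 2048) {
      repeat_thread = 2;
    } else if (block_x <= 4096) {
      repeat_thread = 4;
    } else {
      // Mirrors the PADDLE_THROW branch: lengths above 8192 stay unsupported.
      std::printf("seq_len=%d unsupported (block.x=%d)\n", seq_len, block_x);
      continue;
    }
    std::printf("seq_len=%d -> block.x=%d / REPEAT_THREAD=%d = %d threads per block (limit %d)\n",
                seq_len, block_x, repeat_thread, block_x / repeat_thread,
                kMaxThreadsPerBlock);
  }
  return 0;
}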
@@ -783,6 +783,19 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
}
}
#define SOFTMAX_KERNEL_WITH_MASK(REPEAT_THREAD)                          \
  do {                                                                   \
    block.x /= REPEAT_THREAD;                                            \
    grid.x /= 4;                                                         \
    constexpr int NUM = 4;                                               \
    softmax_kernel_with_mask<half, REPEAT_THREAD, NUM>                   \
        <<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_),  \
                                     (const half *)bias_qk,              \
                                     batch_size,                         \
                                     head_num,                           \
                                     seq_len);                           \
  } while (0)
template <typename T>
inline void MatMulWithHeadQK(const phi::GPUContext &context,
int head_num,
@@ -843,22 +856,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"));
#else
constexpr int ITEMS_PER_THREAD = 1;
bool is_half2 = true;
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + 31) / 32 * 32);
block.x /= ITEMS_PER_THREAD;
assert(block.x <= 1024);
assert(grid.x % 4 == 0);
grid.x /= 4;
constexpr int NUM = 4;
softmax_kernel_with_mask<half, ITEMS_PER_THREAD, NUM>
<<<grid, block, 0, stream>>>(reinterpret_cast<half *>(qk_buf_),
(const half *)bias_qk,
batch_size,
head_num,
seq_len);
SOFTMAX_KERNEL_WITH_MASK(1);
#endif
} else {
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
@@ -888,13 +888,36 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
seq_len / 2,
FINAL_MASK);
} else {
SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
if (bias_is_mask) {
#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700)
PADDLE_ENFORCE_EQ(bias_is_mask,
false,
platform::errors::InvalidArgument(
"QK_bias is mask can't be supported on rocm or "
"cuda_arch<700"));
#else
dim3 grid(seq_len, batch_size, head_num);
dim3 block((seq_len / 2 + 31) / 32 * 32);
if (block.x > 0 && block.x <= 1024) {
SOFTMAX_KERNEL_WITH_MASK(1);
} else if (block.x <= 2048) {
SOFTMAX_KERNEL_WITH_MASK(2);
} else if (block.x <= 4096) {
SOFTMAX_KERNEL_WITH_MASK(4);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot support the length of attention > 8192."));
}
#endif
} else {
SoftmaxKernelWithEltaddForLarge2<__half2><<<grid, block, 0, stream>>>(
reinterpret_cast<__half2 *>(qk_buf_),
reinterpret_cast<const __half2 *>(bias_qk),
batch_size,
head_num,
seq_len / 2,
FINAL_MASK);
}
}
} else {
SoftmaxKernelWithEltaddForLarge<T><<<grid, block, 0, stream>>>(
......