未验证 提交 fc93266b 编写于 作者: J Jeng Bai-Cheng 提交者: GitHub

Improve qkv transpose performance (#23919)

Use vector instruction (LDG.128) to improve qkv transpose. It
provides 1.4X speedup at same GPU base frequency.
test=develop
上级 5b573c58
...@@ -115,7 +115,18 @@ inline void TransposeQKV(const int batch, const int seq_len, ...@@ -115,7 +115,18 @@ inline void TransposeQKV(const int batch, const int seq_len,
const half *input, half *output, cudaStream_t stream) { const half *input, half *output, cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len; int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3); const dim3 grid(seq_len, batch, 3);
if (head_size % 2 == 0 && scratch_size % 2 == 0) { if (head_size % 8 == 0 && scratch_size % 8 == 0) {
int h = head_size / 8;
const int4 *input4 = reinterpret_cast<const int4 *>(input);
int4 *output4 = reinterpret_cast<int4 *>(output);
dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 8));
TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2; const int h = head_size / 2;
const half2 *input2 = reinterpret_cast<const half2 *>(input); const half2 *input2 = reinterpret_cast<const half2 *>(input);
half2 *output2 = reinterpret_cast<half2 *>(output); half2 *output2 = reinterpret_cast<half2 *>(output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册