未验证 提交 635958d9 编写于 作者: F feng_shuai 提交者: GitHub

optimize: vectorize transpose_padding (#48116)

上级 42f35841
......@@ -78,8 +78,6 @@ __global__ void transpose_qkv_padding(
qkv_id * head_num * size_per_head + head_id * size_per_head;
if (seq_id < real_seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
} else if (seq_id < seq_len) {
dst[threadIdx.x + dst_offset] = 0;
}
}
......@@ -91,14 +89,69 @@ __global__ void transpose_qkv_unpadding(const T *src,
const int head_num,
const int size_per_head,
const int real_seq_len) {
int batch_id = blockIdx.x / (head_num * real_seq_len);
int seq_id = blockIdx.x % real_seq_len;
int head_id = blockIdx.x % (head_num * real_seq_len) / real_seq_len;
dst[batch_id * head_num * real_seq_len * size_per_head +
seq_id * head_num * size_per_head + head_id * size_per_head +
threadIdx.x] = src[batch_id * head_num * seq_len * size_per_head +
int batch_id = blockIdx.y;
int seq_id = blockIdx.x;
int head_id = threadIdx.y;
const int src_offset = batch_id * head_num * seq_len * size_per_head +
head_id * seq_len * size_per_head +
seq_id * size_per_head + threadIdx.x];
seq_id * size_per_head;
const int dst_offset = batch_id * real_seq_len * head_num * size_per_head +
seq_id * head_num * size_per_head +
head_id * size_per_head;
dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
}
#define LAUNCH_TRANSPOSE_KERNEL(TYPE, VECTOR_SIZE, PAD_TYPE) \
do { \
int h = head_size / VECTOR_SIZE; \
const TYPE *input##VECTOR_SIZE = reinterpret_cast<const TYPE *>(input); \
TYPE *output##VECTOR_SIZE = reinterpret_cast<TYPE *>(output); \
dim3 block(h, head_num, 1); \
transpose_qkv_##PAD_TYPE<TYPE> \
<<<grid, block, 0, stream>>>(input##VECTOR_SIZE, \
output##VECTOR_SIZE, \
batch, \
seq_len, \
head_num, \
h, \
real_seq_len); \
} while (0)
inline void TransposePadding(const half *input,
half *output,
const int batch,
const int seq_len,
const int head_num,
const int head_size,
const int real_seq_len,
cudaStream_t stream) {
const dim3 grid(seq_len, batch, 3);
if (head_size % 8 == 0) {
LAUNCH_TRANSPOSE_KERNEL(int4, 8, padding);
} else if (head_size % 2 == 0) {
LAUNCH_TRANSPOSE_KERNEL(half2, 2, padding);
} else {
LAUNCH_TRANSPOSE_KERNEL(half, 1, padding);
}
}
inline void TransposeUnPadding(const half *input,
half *output,
const int batch,
const int seq_len,
const int head_num,
const int head_size,
const int real_seq_len,
cudaStream_t stream) {
const dim3 grid(real_seq_len, batch);
if (head_size % 8 == 0) {
LAUNCH_TRANSPOSE_KERNEL(int4, 8, unpadding);
} else if (head_size % 2 == 0) {
LAUNCH_TRANSPOSE_KERNEL(half2, 2, unpadding);
} else {
LAUNCH_TRANSPOSE_KERNEL(half, 1, unpadding);
}
}
int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
......@@ -381,15 +434,14 @@ int QkvToContextPluginDynamic::enqueue(
const half *input1_data = static_cast<const half *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
if (need_padding) {
dim3 grid_p(seq_len, batch, 3);
dim3 block_p(head_size_, head_number_, 1);
transpose_qkv_padding<<<grid_p, block_p, 0, stream>>>(input0_data,
TransposePadding(input0_data,
tptr,
batch,
seq_len,
head_number_,
head_size_,
real_seq_len);
real_seq_len,
stream);
} else {
TransposeQKV(
batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
......@@ -424,10 +476,14 @@ int QkvToContextPluginDynamic::enqueue(
int block = head_size_;
half *output = static_cast<half *>(outputs[0]);
if (need_padding) {
int grid_u = batch * head_number_ * real_seq_len;
int block_u = head_size_;
transpose_qkv_unpadding<half><<<grid_u, block_u, 0, stream>>>(
tptr, output, batch, seq_len, head_number_, head_size_, real_seq_len);
TransposeUnPadding(tptr,
output,
batch,
seq_len,
head_number_,
head_size_,
real_seq_len,
stream);
} else {
transpose<half><<<grid, block, 0, stream>>>(
tptr, output, batch, seq_len, head_number_, head_size_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册