未验证 提交 635958d9 编写于 作者: F feng_shuai 提交者: GitHub

optimize: vectorize transpose_padding (#48116)

上级 42f35841
...@@ -78,8 +78,6 @@ __global__ void transpose_qkv_padding( ...@@ -78,8 +78,6 @@ __global__ void transpose_qkv_padding(
qkv_id * head_num * size_per_head + head_id * size_per_head; qkv_id * head_num * size_per_head + head_id * size_per_head;
if (seq_id < real_seq_len) { if (seq_id < real_seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset]; dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
} else if (seq_id < seq_len) {
dst[threadIdx.x + dst_offset] = 0;
} }
} }
...@@ -91,14 +89,69 @@ __global__ void transpose_qkv_unpadding(const T *src, ...@@ -91,14 +89,69 @@ __global__ void transpose_qkv_unpadding(const T *src,
const int head_num, const int head_num,
const int size_per_head, const int size_per_head,
const int real_seq_len) { const int real_seq_len) {
int batch_id = blockIdx.x / (head_num * real_seq_len); int batch_id = blockIdx.y;
int seq_id = blockIdx.x % real_seq_len; int seq_id = blockIdx.x;
int head_id = blockIdx.x % (head_num * real_seq_len) / real_seq_len; int head_id = threadIdx.y;
dst[batch_id * head_num * real_seq_len * size_per_head + const int src_offset = batch_id * head_num * seq_len * size_per_head +
seq_id * head_num * size_per_head + head_id * size_per_head +
threadIdx.x] = src[batch_id * head_num * seq_len * size_per_head +
head_id * seq_len * size_per_head + head_id * seq_len * size_per_head +
seq_id * size_per_head + threadIdx.x]; seq_id * size_per_head;
const int dst_offset = batch_id * real_seq_len * head_num * size_per_head +
seq_id * head_num * size_per_head +
head_id * size_per_head;
dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
}
#define LAUNCH_TRANSPOSE_KERNEL(TYPE, VECTOR_SIZE, PAD_TYPE) \
do { \
int h = head_size / VECTOR_SIZE; \
const TYPE *input##VECTOR_SIZE = reinterpret_cast<const TYPE *>(input); \
TYPE *output##VECTOR_SIZE = reinterpret_cast<TYPE *>(output); \
dim3 block(h, head_num, 1); \
transpose_qkv_##PAD_TYPE<TYPE> \
<<<grid, block, 0, stream>>>(input##VECTOR_SIZE, \
output##VECTOR_SIZE, \
batch, \
seq_len, \
head_num, \
h, \
real_seq_len); \
} while (0)
inline void TransposePadding(const half *input,
half *output,
const int batch,
const int seq_len,
const int head_num,
const int head_size,
const int real_seq_len,
cudaStream_t stream) {
const dim3 grid(seq_len, batch, 3);
if (head_size % 8 == 0) {
LAUNCH_TRANSPOSE_KERNEL(int4, 8, padding);
} else if (head_size % 2 == 0) {
LAUNCH_TRANSPOSE_KERNEL(half2, 2, padding);
} else {
LAUNCH_TRANSPOSE_KERNEL(half, 1, padding);
}
}
inline void TransposeUnPadding(const half *input,
half *output,
const int batch,
const int seq_len,
const int head_num,
const int head_size,
const int real_seq_len,
cudaStream_t stream) {
const dim3 grid(real_seq_len, batch);
if (head_size % 8 == 0) {
LAUNCH_TRANSPOSE_KERNEL(int4, 8, unpadding);
} else if (head_size % 2 == 0) {
LAUNCH_TRANSPOSE_KERNEL(half2, 2, unpadding);
} else {
LAUNCH_TRANSPOSE_KERNEL(half, 1, unpadding);
}
} }
int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
...@@ -381,15 +434,14 @@ int QkvToContextPluginDynamic::enqueue( ...@@ -381,15 +434,14 @@ int QkvToContextPluginDynamic::enqueue(
const half *input1_data = static_cast<const half *>(qk_bias); const half *input1_data = static_cast<const half *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH. // BxSx3xNxH => tptr: 3xBxNxSxH.
if (need_padding) { if (need_padding) {
dim3 grid_p(seq_len, batch, 3); TransposePadding(input0_data,
dim3 block_p(head_size_, head_number_, 1);
transpose_qkv_padding<<<grid_p, block_p, 0, stream>>>(input0_data,
tptr, tptr,
batch, batch,
seq_len, seq_len,
head_number_, head_number_,
head_size_, head_size_,
real_seq_len); real_seq_len,
stream);
} else { } else {
TransposeQKV( TransposeQKV(
batch, seq_len, head_size_, head_number_, input0_data, tptr, stream); batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
...@@ -424,10 +476,14 @@ int QkvToContextPluginDynamic::enqueue( ...@@ -424,10 +476,14 @@ int QkvToContextPluginDynamic::enqueue(
int block = head_size_; int block = head_size_;
half *output = static_cast<half *>(outputs[0]); half *output = static_cast<half *>(outputs[0]);
if (need_padding) { if (need_padding) {
int grid_u = batch * head_number_ * real_seq_len; TransposeUnPadding(tptr,
int block_u = head_size_; output,
transpose_qkv_unpadding<half><<<grid_u, block_u, 0, stream>>>( batch,
tptr, output, batch, seq_len, head_number_, head_size_, real_seq_len); seq_len,
head_number_,
head_size_,
real_seq_len,
stream);
} else { } else {
transpose<half><<<grid, block, 0, stream>>>( transpose<half><<<grid, block, 0, stream>>>(
tptr, output, batch, seq_len, head_number_, head_size_); tptr, output, batch, seq_len, head_number_, head_size_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册