diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 8cb8b7f4b7e2044a706ec80abc052d62afa7e8cb..5e3f078cf9f4d586501d10ec34be8ac25ea8868a 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -78,8 +78,6 @@ __global__ void transpose_qkv_padding(
       qkv_id * head_num * size_per_head + head_id * size_per_head;
   if (seq_id < real_seq_len) {
     dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
-  } else if (seq_id < seq_len) {
-    dst[threadIdx.x + dst_offset] = 0;
   }
 }
 
@@ -91,14 +89,69 @@ __global__ void transpose_qkv_unpadding(const T *src,
                                         const int head_num,
                                         const int size_per_head,
                                         const int real_seq_len) {
-  int batch_id = blockIdx.x / (head_num * real_seq_len);
-  int seq_id = blockIdx.x % real_seq_len;
-  int head_id = blockIdx.x % (head_num * real_seq_len) / real_seq_len;
-  dst[batch_id * head_num * real_seq_len * size_per_head +
-      seq_id * head_num * size_per_head + head_id * size_per_head +
-      threadIdx.x] = src[batch_id * head_num * seq_len * size_per_head +
+  int batch_id = blockIdx.y;
+  int seq_id = blockIdx.x;
+  int head_id = threadIdx.y;
+  const int src_offset = batch_id * head_num * seq_len * size_per_head +
                          head_id * seq_len * size_per_head +
-                        seq_id * size_per_head + threadIdx.x];
+                        seq_id * size_per_head;
+  const int dst_offset = batch_id * real_seq_len * head_num * size_per_head +
+                         seq_id * head_num * size_per_head +
+                         head_id * size_per_head;
+
+  dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
+}
+
+#define LAUNCH_TRANSPOSE_KERNEL(TYPE, VECTOR_SIZE, PAD_TYPE)                \
+  do {                                                                      \
+    int h = head_size / VECTOR_SIZE;                                        \
+    const TYPE *input##VECTOR_SIZE = reinterpret_cast<const TYPE *>(input); \
+    TYPE *output##VECTOR_SIZE = reinterpret_cast<TYPE *>(output);           \
+    dim3 block(h, head_num, 1);                                             \
+    transpose_qkv_##PAD_TYPE                                                \
+        <<<grid, block, 0, stream>>>(input##VECTOR_SIZE,                    \
+                                     output##VECTOR_SIZE,                   \
+                                     batch,                                 \
+                                     seq_len,                               \
+                                     head_num,                              \
+                                     h,                                     \
+                                     real_seq_len);                         \
+  } while (0)
+
+inline void TransposePadding(const half *input,
+                             half *output,
+                             const int batch,
+                             const int seq_len,
+                             const int head_num,
+                             const int head_size,
+                             const int real_seq_len,
+                             cudaStream_t stream) {
+  const dim3 grid(seq_len, batch, 3);
+  if (head_size % 8 == 0) {
+    LAUNCH_TRANSPOSE_KERNEL(int4, 8, padding);
+  } else if (head_size % 2 == 0) {
+    LAUNCH_TRANSPOSE_KERNEL(half2, 2, padding);
+  } else {
+    LAUNCH_TRANSPOSE_KERNEL(half, 1, padding);
+  }
+}
+
+inline void TransposeUnPadding(const half *input,
+                               half *output,
+                               const int batch,
+                               const int seq_len,
+                               const int head_num,
+                               const int head_size,
+                               const int real_seq_len,
+                               cudaStream_t stream) {
+  const dim3 grid(real_seq_len, batch);
+  if (head_size % 8 == 0) {
+    LAUNCH_TRANSPOSE_KERNEL(int4, 8, unpadding);
+  } else if (head_size % 2 == 0) {
+    LAUNCH_TRANSPOSE_KERNEL(half2, 2, unpadding);
+  } else {
+    LAUNCH_TRANSPOSE_KERNEL(half, 1, unpadding);
+  }
 }
 
 int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
@@ -381,15 +434,14 @@ int QkvToContextPluginDynamic::enqueue(
     const half *input1_data = static_cast<const half *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
     if (need_padding) {
-      dim3 grid_p(seq_len, batch, 3);
-      dim3 block_p(head_size_, head_number_, 1);
-      transpose_qkv_padding<<<grid_p, block_p, 0, stream>>>(input0_data,
-                                                            tptr,
-                                                            batch,
-                                                            seq_len,
-                                                            head_number_,
-                                                            head_size_,
-                                                            real_seq_len);
+      TransposePadding(input0_data,
+                       tptr,
+                       batch,
+                       seq_len,
+                       head_number_,
+                       head_size_,
+                       real_seq_len,
+                       stream);
     } else {
       TransposeQKV(
           batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
@@ -424,10 +476,14 @@ int QkvToContextPluginDynamic::enqueue(
     int block = head_size_;
     half *output = static_cast<half *>(outputs[0]);
     if (need_padding) {
-      int grid_u = batch * head_number_ * real_seq_len;
-      int block_u = head_size_;
-      transpose_qkv_unpadding<<<grid_u, block_u, 0, stream>>>(
-          tptr, output, batch, seq_len, head_number_, head_size_, real_seq_len);
+      TransposeUnPadding(tptr,
+                         output,
+                         batch,
+                         seq_len,
+                         head_number_,
+                         head_size_,
+                         real_seq_len,
+                         stream);
     } else {
       transpose<<<grid, block, 0, stream>>>(
           tptr, output, batch, seq_len, head_number_, head_size_);
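
For reviewers, the snippet below is roughly what the head_size % 8 == 0 branch inside the new TransposePadding boils down to once LAUNCH_TRANSPOSE_KERNEL(int4, 8, padding) is expanded. It is an illustrative sketch only, not part of the patch; the surrounding names (input, output, batch, seq_len, head_num, head_size, real_seq_len, stream) are the TransposePadding parameters, and input8/output8 come from the macro's token pasting.

    // Illustrative expansion only; assumes head_size is a multiple of 8 so each
    // thread copies one int4, i.e. 8 half values, per transfer.
    const dim3 grid(seq_len, batch, 3);  // blockIdx: x = seq position, y = batch, z = q/k/v
    int h = head_size / 8;               // head dimension counted in int4 units
    const int4 *input8 = reinterpret_cast<const int4 *>(input);
    int4 *output8 = reinterpret_cast<int4 *>(output);
    dim3 block(h, head_num, 1);          // threadIdx: x = vectorized head dim, y = head id
    transpose_qkv_padding<<<grid, block, 0, stream>>>(
        input8, output8, batch, seq_len, head_num, h, real_seq_len);

Compared with the removed per-element launch (blockDim.x equal to head_size), this vectorized launch cuts blockDim.x by a factor of 8 and moves data in 16-byte transactions; the half2 and plain half branches follow the same pattern for head sizes that are not multiples of 8 (or of 2).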