diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index fe3ea180593b914b1fec948644723ec0a535b4d7..240ecaa25893d04fe4836d08998a312582425f2f 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -115,7 +115,18 @@ inline void TransposeQKV(const int batch, const int seq_len,
                          const half *input, half *output, cudaStream_t stream) {
   int scratch_size = batch * head_num * seq_len * seq_len;
   const dim3 grid(seq_len, batch, 3);
-  if (head_size % 2 == 0 && scratch_size % 2 == 0) {
+  if (head_size % 8 == 0 && scratch_size % 8 == 0) {
+    int h = head_size / 8;
+    const int4 *input4 = reinterpret_cast<const int4 *>(input);
+    int4 *output4 = reinterpret_cast<int4 *>(output);
+    dim3 block(h, head_num, 1);
+    // limit h * head_num to max block size(1024).
+    PADDLE_ENFORCE_LE(h * head_num, 1024,
+                      platform::errors::InvalidArgument(
+                          "head_num (%d) * head_size (%d) should <= %d",
+                          head_num, head_size, 1024 * 8));
+    TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
+  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
     const int h = head_size / 2;
     const half2 *input2 = reinterpret_cast<const half2 *>(input);
     half2 *output2 = reinterpret_cast<half2 *>(output);
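
Note: the hunk dispatches to TransposeQkvKernel<int4>, whose body lies outside this diff. Below is a minimal sketch of what such a vectorized transpose kernel could look like, assuming the same [B, S, 3, N, H] -> [3, B, N, S, H] QKV layout implied by the launch configuration above (grid = (seq_len, batch, 3), block = (head_size/8, head_num, 1)). The index math is an illustration, not necessarily the file's actual kernel.

// Sketch only: a generic transpose kernel over vector type T. Each thread
// copies one T; with T = int4 that is 16 bytes, i.e. 8 halves per access.
// H is head_size expressed in units of T (head_size / 8 for int4).
template <typename T>
__global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
  const int n = threadIdx.y;  // head index, 0..N-1
  const int s = blockIdx.x;   // sequence position
  const int b = blockIdx.y;   // batch index
  const int m = blockIdx.z;   // 0 = Q, 1 = K, 2 = V

  const int N = blockDim.y;   // head_num
  const int S = gridDim.x;    // seq_len
  const int B = gridDim.y;    // batch

  const int NH = N * H;
  const int NHS = NH * S;
  // input laid out as [B, S, 3, N, H]: element (b, s, m, n, :)
  const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
  // output laid out as [3, B, N, S, H]: element (m, b, n, s, :)
  const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;

  const int i = threadIdx.x;  // 0..H-1, in units of T
  output[out_offset + i] = input[in_offset + i];
}

Design note: reinterpreting the half buffers as int4 simply forces 128-bit loads and stores (8 halves per memory transaction), which is why this path requires head_size % 8 == 0; the existing half2 branch remains the fallback, giving 32-bit accesses when only 2-element alignment is available.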