diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index de40ded24e3791dca72abd48345c3a149e4a11a4..020e5efba88d2cc6d042685baf82bbc95f3faac7 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -370,8 +370,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
                       const int head_num, const float *input, const float *bias,
                       float *output, cudaStream_t stream) {
   // BxSx3xNxH + 3xNxH -> 3xBxNxSxH
+  int scratch_size = batch * head_num * seq_len * seq_len;
   const dim3 grid(seq_len, batch, 3);
-  if (head_size % 4 == 0) {
+  // scratch % 4 == 0 to ensure the alignment
+  if (head_size % 4 == 0 && scratch_size % 4 == 0) {
     const int h = head_size / 4;
     const float4 *input4 = reinterpret_cast<const float4 *>(input);
     const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
@@ -385,7 +387,7 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
                           head_num, head_size, 1024 * 4));
     transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
                                                              output4);
-  } else if (head_size % 2 == 0) {
+  } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
     const int h = head_size / 2;
     const float2 *input2 = reinterpret_cast<const float2 *>(input);
     const float2 *bias2 = reinterpret_cast<const float2 *>(bias);