From 8c6fde9e691341085238adb9ffe10a43f57b844a Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Thu, 19 Mar 2020 10:29:02 +0800 Subject: [PATCH] fix align error (#23090) test=develop --- paddle/fluid/operators/fused/multihead_matmul_op.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index de40ded24e..020e5efba8 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -370,8 +370,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, const int head_num, const float *input, const float *bias, float *output, cudaStream_t stream) { // BxSx3xNxH + 3xNxH -> 3xBxNxSxH + int scratch_size = batch * head_num * seq_len * seq_len; const dim3 grid(seq_len, batch, 3); - if (head_size % 4 == 0) { + // scratch % 4 == 0 to ensure the alignment + if (head_size % 4 == 0 && scratch_size % 4 == 0) { const int h = head_size / 4; const float4 *input4 = reinterpret_cast(input); const float4 *bias4 = reinterpret_cast(bias); @@ -385,7 +387,7 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, head_num, head_size, 1024 * 4)); transpose_qkv_kernel<<>>(h, input4, bias4, output4); - } else if (head_size % 2 == 0) { + } else if (head_size % 2 == 0 && scratch_size % 2 == 0) { const int h = head_size / 2; const float2 *input2 = reinterpret_cast(input); const float2 *bias2 = reinterpret_cast(bias); -- GitLab