未验证 提交 8c6fde9e 编写于 作者: Z Zhaolong Xing 提交者: GitHub

fix align error (#23090)

test=develop
上级 915b892a
...@@ -370,8 +370,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, ...@@ -370,8 +370,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
const int head_num, const float *input, const float *bias, const int head_num, const float *input, const float *bias,
float *output, cudaStream_t stream) { float *output, cudaStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH // BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3); const dim3 grid(seq_len, batch, 3);
if (head_size % 4 == 0) { // scratch % 4 == 0 to ensure the alignment
if (head_size % 4 == 0 && scratch_size % 4 == 0) {
const int h = head_size / 4; const int h = head_size / 4;
const float4 *input4 = reinterpret_cast<const float4 *>(input); const float4 *input4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias); const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
...@@ -385,7 +387,7 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size, ...@@ -385,7 +387,7 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
head_num, head_size, 1024 * 4)); head_num, head_size, 1024 * 4));
transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4, transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
output4); output4);
} else if (head_size % 2 == 0) { } else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2; const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input); const float2 *input2 = reinterpret_cast<const float2 *>(input);
const float2 *bias2 = reinterpret_cast<const float2 *>(bias); const float2 *bias2 = reinterpret_cast<const float2 *>(bias);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册