未验证 提交 8c6fde9e 编写于 作者: Z Zhaolong Xing 提交者: GitHub

fix align error (#23090)

test=develop
上级 915b892a
......@@ -370,8 +370,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
const int head_num, const float *input, const float *bias,
float *output, cudaStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 4 == 0) {
// scratch % 4 == 0 to ensure the alignment
if (head_size % 4 == 0 && scratch_size % 4 == 0) {
const int h = head_size / 4;
const float4 *input4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
......@@ -385,7 +387,7 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
head_num, head_size, 1024 * 4));
transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
output4);
} else if (head_size % 2 == 0) {
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input);
const float2 *bias2 = reinterpret_cast<const float2 *>(bias);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册