From 24379442903e98ec85b0149717e68fa2778b52ae Mon Sep 17 00:00:00 2001 From: ZZK <359521840@qq.com> Date: Wed, 18 Jan 2023 18:44:27 +0800 Subject: [PATCH] Add align check for Concat Kernel (#49761) * add align check * refine --- .../phi/kernels/funcs/concat_and_split_functor.cu | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu index af2ab0a8a70..fa663528eb0 100644 --- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu +++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu @@ -463,12 +463,17 @@ void DispatchConcatKernel(const phi::GPUContext& ctx, constexpr IndexT MaxVecSize = 16 / sizeof(T); bool find_vecsize_flag = false; IndexT dispatch_vec_size = 1; + + auto output_data = reinterpret_cast<std::uintptr_t>(output->data()); for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) { - for (IndexT idx = 0; idx < in_num + 1; idx++) { + const IndexT mov_size = vec_size * sizeof(T); + for (IndexT idx = 1; idx < in_num + 1; idx++) { + auto input_data = reinterpret_cast<std::uintptr_t>(inputs_data[idx - 1]); // Since input_cols[0] is 0, we need to jump. - const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx]; - if (input_col % vec_size == 0) { - if (idx == in_num - 1) { + const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1]; + if (input_col % vec_size == 0 && output_data % mov_size == 0 && + input_data % mov_size == 0) { + if (idx == in_num) { find_vecsize_flag = true; } } else { -- GitLab