diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu index af2ab0a8a70a36aa0d69221d27916533f3d16afa..fa663528eb0158fa75f8d3e86f83115621491816 100644 --- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu +++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu @@ -463,12 +463,17 @@ void DispatchConcatKernel(const phi::GPUContext& ctx, constexpr IndexT MaxVecSize = 16 / sizeof(T); bool find_vecsize_flag = false; IndexT dispatch_vec_size = 1; + + auto output_data = reinterpret_cast(output->data()); for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) { - for (IndexT idx = 0; idx < in_num + 1; idx++) { + const IndexT mov_size = vec_size * sizeof(T); + for (IndexT idx = 1; idx < in_num + 1; idx++) { + auto input_data = reinterpret_cast(inputs_data[idx - 1]); // Since input_cols[0] is 0, we need to jump. - const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx]; - if (input_col % vec_size == 0) { - if (idx == in_num - 1) { + const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1]; + if (input_col % vec_size == 0 && output_data % mov_size == 0 && + input_data % mov_size == 0) { + if (idx == in_num) { find_vecsize_flag = true; } } else {