未验证 提交 24379442 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Add align check for Concat Kernel (#49761)

* add align check

* refine
上级 55ccb429
......@@ -463,12 +463,17 @@ void DispatchConcatKernel(const phi::GPUContext& ctx,
constexpr IndexT MaxVecSize = 16 / sizeof(T);
bool find_vecsize_flag = false;
IndexT dispatch_vec_size = 1;
auto output_data = reinterpret_cast<std::uintptr_t>(output->data());
for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
for (IndexT idx = 0; idx < in_num + 1; idx++) {
const IndexT mov_size = vec_size * sizeof(T);
for (IndexT idx = 1; idx < in_num + 1; idx++) {
auto input_data = reinterpret_cast<std::uintptr_t>(inputs_data[idx - 1]);
// Since input_cols[0] is 0, we need to jump.
const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx];
if (input_col % vec_size == 0) {
if (idx == in_num - 1) {
const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1];
if (input_col % vec_size == 0 && output_data % mov_size == 0 &&
input_data % mov_size == 0) {
if (idx == in_num) {
find_vecsize_flag = true;
}
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册