未验证 提交 d128c286 编写于 作者: S sunli 提交者: GitHub

optimize index computation of roll (#33909)

上级 b1c458d0
......@@ -36,13 +36,16 @@ __global__ void RollCudaKernel(const T* input, T* output, int64_t N,
}
int64_t output_idx = idx;
int64_t dim_idx, dim_idx_shift;
int64_t new_dim_idx = 0;
#pragma unroll Rank
#pragma unroll
for (size_t i = 0; i < Rank; i++) {
dim_idx = (idx / strides[i]) % sizes[i];
dim_idx_shift = (dim_idx + shifts[i]) % sizes[i];
output_idx = output_idx + (dim_idx_shift - dim_idx) * strides[i];
new_dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
if (new_dim_idx >= sizes[i]) {
output_idx += (shifts[i] - sizes[i]) * strides[i];
} else {
output_idx += shifts[i] * strides[i];
}
}
output[output_idx] = input[idx];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册