未验证 提交 89fb1964 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick-2.2.1]Opt topk (#37325)

目前的fused_attention_op不支持attn_mask=None的输入,本PR对此进行了补充,并补充了相应的单测逻辑。
上级 d31d597f
...@@ -83,7 +83,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> { ...@@ -83,7 +83,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
if (k > input_width) k = input_width; if (k > input_width) k = input_width;
if ((input_width <= 1024 || k >= 128 || k == input_width)) { // The conclusion is drawn from the data through multiple sets of
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output, if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
indices, largest)) { indices, largest)) {
// Successed, return. // Successed, return.
...@@ -159,8 +161,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> { ...@@ -159,8 +161,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
if (k > input_width) k = input_width; if (k > input_width) k = input_width;
if (((input_width <= 1024 && input_height <= 2048) || k >= 128 || // The conclusion is drawn from the data through multiple sets of
k == input_width)) { // statistics
if (input_width >= 128 && k >= input_width * 0.75) {
if (SortTopk<T>(dev_ctx, &trans_input, input_width, input_height, k, if (SortTopk<T>(dev_ctx, &trans_input, input_width, input_height, k,
&trans_out, &trans_ind, largest)) { &trans_out, &trans_ind, largest)) {
// last step, tranpose back the indices and output // last step, tranpose back the indices and output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册