未验证 提交 667082c0 编写于 作者: C carryyu 提交者: GitHub

fix P40 topk: Make the optimized topk compatible with P40. (#46547)

* fix P40 topk: Make the optimized topk compatible with P40.

* fix P40 topk: Make the optimized topk compatible with P40.

* fix P40 topk: Make the optimized topk compatible with P40.
上级 d71f1b3f
......@@ -354,16 +354,11 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
}
if (--(*k) == 0) break;
if (MaxLength < 5) {
if (*beam >= MaxLength) break;
} else {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (tid_max / 32 == wid) {
if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) ==
MaxLength)
break;
}
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (tid_max / 32 == wid) {
if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength)
break;
}
}
}
......
......@@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
input_height));
default:
PADDLE_THROW(platform::errors::Fatal(
"the input k has error in the topk cuda kernel."));
"the input k has error when use getMaxLength function to get the "
"maxLength."));
});
default:
PADDLE_THROW(platform::errors::Unavailable(
......
......@@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx,
largest));
default:
PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel."));
errors::Fatal("the input k has error when use getMaxLength "
"function to get the maxLength."));
});
#endif
default:
......@@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx,
largest));
default:
PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel."));
errors::Fatal("the input k has error when use getMaxLength "
"function to get the maxLength."));
});
#endif
default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册