From 29453da112c8530b64bda8fbb86ec458226977bf Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Mon, 14 Mar 2022 10:48:54 +0800 Subject: [PATCH] Fix bug when eigen_device() is nullptr in top_k (#40459) --- paddle/phi/kernels/gpu/top_k_kernel.cu | 38 ++++++++++++++------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 4e9aa88c6cb..7f06af7de43 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -78,15 +78,16 @@ void TopkKernel(const Context& dev_ctx, // The conclusion is drawn from the data through multiple sets of // statistics if (input_width >= 128 && k >= input_width * 0.75) { - if (ops::SortTopk( - paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()), - input, - input_width, - input_height, - k, - out, - indices, - largest)) { + auto* ctx = reinterpret_cast( + &dev_ctx); + if (ops::SortTopk(*ctx, + input, + input_width, + input_height, + k, + out, + indices, + largest)) { // Successed, return. return; } else { @@ -181,15 +182,16 @@ void TopkKernel(const Context& dev_ctx, // The conclusion is drawn from the data through multiple sets of // statistics if (input_width >= 128 && k >= input_width * 0.75) { - if (ops::SortTopk( - paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()), - &trans_input, - input_width, - input_height, - k, - &trans_out, - &trans_ind, - largest)) { + auto* ctx = reinterpret_cast( + &dev_ctx); + if (ops::SortTopk(*ctx, + &trans_input, + input_width, + input_height, + k, + &trans_out, + &trans_ind, + largest)) { // last step, tranpose back the indices and output funcs::TransCompute( ndims, dev_ctx, trans_ind, indices, trans); -- GitLab