diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 4e9aa88c6cb2da7fabe3f5d841a313e82b9ebed2..7f06af7de43f7ee234831203c485eaa0b8c86cbf 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -78,15 +78,16 @@ void TopkKernel(const Context& dev_ctx, // The conclusion is drawn from the data through multiple sets of // statistics if (input_width >= 128 && k >= input_width * 0.75) { - if (ops::SortTopk( - paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()), - input, - input_width, - input_height, - k, - out, - indices, - largest)) { + auto* ctx = reinterpret_cast( + &dev_ctx); + if (ops::SortTopk(*ctx, + input, + input_width, + input_height, + k, + out, + indices, + largest)) { // Successed, return. return; } else { @@ -181,15 +182,16 @@ void TopkKernel(const Context& dev_ctx, // The conclusion is drawn from the data through multiple sets of // statistics if (input_width >= 128 && k >= input_width * 0.75) { - if (ops::SortTopk( - paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()), - &trans_input, - input_width, - input_height, - k, - &trans_out, - &trans_ind, - largest)) { + auto* ctx = reinterpret_cast( + &dev_ctx); + if (ops::SortTopk(*ctx, + &trans_input, + input_width, + input_height, + k, + &trans_out, + &trans_ind, + largest)) { // last step, tranpose back the indices and output funcs::TransCompute( ndims, dev_ctx, trans_ind, indices, trans);