diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 6fa66e9da5a437d70f217951445e57a1ed7af378..2fb277e9032b2db94c6c268893ba8bfa88465f6d 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -118,7 +118,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "sgd", "svd", "sync_batch_norm_grad", - "top_k", "unique", "unique_consecutive_flattened_tensor", "unique_raw", diff --git a/paddle/phi/kernels/cpu/top_k_kernel.cc b/paddle/phi/kernels/cpu/top_k_kernel.cc index 5a5789effadf95da04c5e160896714bb12c4be96..4f67d21777e9e376b8fb4d9bac5792bbd325f684 100644 --- a/paddle/phi/kernels/cpu/top_k_kernel.cc +++ b/paddle/phi/kernels/cpu/top_k_kernel.cc @@ -241,4 +241,6 @@ void TopkKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {} + topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) { + kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 01bae0dd96cc5f25a1115d3b64e6f5cc5e15fb19..cd6294eaf057307e0b2a03b75f88ce417eab0773 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -348,4 +348,6 @@ PD_REGISTER_KERNEL(topk, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/xpu/top_k_kernel.cc b/paddle/phi/kernels/xpu/top_k_kernel.cc index fca852a086d506d7092a1202722bc5c2c0613854..19b9209303bea3c83b77a7347aae1d57bc3eba9f 100644 --- a/paddle/phi/kernels/xpu/top_k_kernel.cc +++ b/paddle/phi/kernels/xpu/top_k_kernel.cc @@ -187,4 +187,6 @@ void TopkKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) {} + topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); +}