From 35d31e9af9f19d0220f15fa342c3b265f052c89a Mon Sep 17 00:00:00 2001 From: Huang Jiyi <43315610+huangjiyi@users.noreply.github.com> Date: Wed, 8 Mar 2023 16:52:45 +0800 Subject: [PATCH] Add output defs for some kernels (#51333) --- .../framework/new_executor/interpreter/interpreter_util.cc | 4 ---- paddle/phi/kernels/cpu/lstsq_kernel.cc | 4 +++- paddle/phi/kernels/cpu/lu_kernel.cc | 5 ++++- paddle/phi/kernels/cpu/matrix_nms_kernel.cc | 5 ++++- paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc | 1 + paddle/phi/kernels/gpu/lu_kernel.cu | 5 ++++- paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu | 4 +++- 7 files changed, 19 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 9683f0cb958..bdcb841c16b 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -88,10 +88,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "layer_norm_grad", "less_equal", "less_than", - "lstsq", - "lu", - "matrix_nms", - "matrix_rank_tol", "merged_adam", "mode", "momentum", diff --git a/paddle/phi/kernels/cpu/lstsq_kernel.cc b/paddle/phi/kernels/cpu/lstsq_kernel.cc index 6702ea78393..2e3f9ea8ace 100644 --- a/paddle/phi/kernels/cpu/lstsq_kernel.cc +++ b/paddle/phi/kernels/cpu/lstsq_kernel.cc @@ -301,4 +301,6 @@ void LstsqKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} +PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) { + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/cpu/lu_kernel.cc b/paddle/phi/kernels/cpu/lu_kernel.cc index 14cbab53663..731a722372d 100644 --- a/paddle/phi/kernels/cpu/lu_kernel.cc +++ b/paddle/phi/kernels/cpu/lu_kernel.cc @@ -73,4 +73,7 @@ void LUKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {} +PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/cpu/matrix_nms_kernel.cc b/paddle/phi/kernels/cpu/matrix_nms_kernel.cc index aa9f778d1e4..3c51468e7c7 100644 --- a/paddle/phi/kernels/cpu/matrix_nms_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_nms_kernel.cc @@ -318,4 +318,7 @@ void MatrixNMSKernel(const Context& ctx, } // namespace phi PD_REGISTER_KERNEL( - matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {} + matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc index 491e9c5d210..fbb16138567 100644 --- a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc @@ -175,4 +175,5 @@ void MatrixRankTolKernel(const Context& dev_ctx, PD_REGISTER_KERNEL( matrix_rank_tol, CPU, ALL_LAYOUT, phi::MatrixRankTolKernel, float, double) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT64); } diff --git a/paddle/phi/kernels/gpu/lu_kernel.cu b/paddle/phi/kernels/gpu/lu_kernel.cu index 57a7366a239..d26826eccd1 100644 --- a/paddle/phi/kernels/gpu/lu_kernel.cu +++ b/paddle/phi/kernels/gpu/lu_kernel.cu @@ -183,6 +183,9 @@ PD_REGISTER_KERNEL(lu, // cuda_only ALL_LAYOUT, phi::LUKernel, float, - double) {} + double) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} #endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu index 6b3c59505c2..620341f338e 100644 --- a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu @@ -445,6 +445,8 @@ PD_REGISTER_KERNEL(matrix_rank_tol, // cuda_only ALL_LAYOUT, phi::MatrixRankTolKernel, float, - double) {} + double) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT64); +} #endif // not PADDLE_WITH_HIP -- GitLab