From 262358e82c2a012bdf999cf43fdacf375970b09d Mon Sep 17 00:00:00 2001 From: Ainavo <57820731+Ainavo@users.noreply.github.com> Date: Wed, 8 Mar 2023 16:59:21 +0800 Subject: [PATCH] add output defs for nonzero and nms (#51325) --- .../framework/new_executor/interpreter/interpreter_util.cc | 2 -- paddle/phi/kernels/cpu/nms_kernel.cc | 4 +++- paddle/phi/kernels/cpu/nonzero_kernel.cc | 4 +++- paddle/phi/kernels/gpu/nms_kernel.cu | 4 +++- paddle/phi/kernels/gpu/nonzero_kernel.cu | 4 +++- paddle/phi/kernels/xpu/nonzero_kernel.cc | 4 +++- 6 files changed, 15 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index bdcb841c16b..2e1e8cf8763 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -94,8 +94,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "multiclass_nms3", "multinomial", "nanmedian", - "nms", - "nonzero", "numl", "qr", "qr_grad", diff --git a/paddle/phi/kernels/cpu/nms_kernel.cc b/paddle/phi/kernels/cpu/nms_kernel.cc index 6743f13fff7..cf663bd8346 100644 --- a/paddle/phi/kernels/cpu/nms_kernel.cc +++ b/paddle/phi/kernels/cpu/nms_kernel.cc @@ -94,4 +94,6 @@ void NMSKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} +PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/cpu/nonzero_kernel.cc b/paddle/phi/kernels/cpu/nonzero_kernel.cc index fca8e6b09fc..70b2a212307 100644 --- a/paddle/phi/kernels/cpu/nonzero_kernel.cc +++ b/paddle/phi/kernels/cpu/nonzero_kernel.cc @@ -92,4 +92,6 @@ PD_REGISTER_KERNEL(nonzero, int16_t, bool, float, - double) {} + double) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/nms_kernel.cu b/paddle/phi/kernels/gpu/nms_kernel.cu index c8b067444df..017196973de 100644 --- a/paddle/phi/kernels/gpu/nms_kernel.cu +++ b/paddle/phi/kernels/gpu/nms_kernel.cu @@ -113,4 +113,6 @@ void NMSKernel(const Context& dev_ctx, dev_ctx.stream()); } } // namespace phi -PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} +PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/nonzero_kernel.cu b/paddle/phi/kernels/gpu/nonzero_kernel.cu index 11139c7d65d..b4aab6fe6f8 100644 --- a/paddle/phi/kernels/gpu/nonzero_kernel.cu +++ b/paddle/phi/kernels/gpu/nonzero_kernel.cu @@ -83,4 +83,6 @@ PD_REGISTER_KERNEL(nonzero, int16_t, bool, float, - double) {} + double) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/xpu/nonzero_kernel.cc b/paddle/phi/kernels/xpu/nonzero_kernel.cc index 35d093c4475..aa1e39a897c 100644 --- a/paddle/phi/kernels/xpu/nonzero_kernel.cc +++ b/paddle/phi/kernels/xpu/nonzero_kernel.cc @@ -69,4 +69,6 @@ void NonZeroKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {} + nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64); +} -- GitLab