From 939b58b2c796ebe01c5b9979551fe9788f68e717 Mon Sep 17 00:00:00 2001 From: Sanbu <96160062+sanbuphy@users.noreply.github.com> Date: Thu, 16 Mar 2023 10:54:48 +0800 Subject: [PATCH] Add output defs for generate_proposals,instance_norm kernel (#51576) * Add output defs for generate_proposals,instance_norm kernel * fix --- .../framework/new_executor/interpreter/interpreter_util.cc | 2 -- paddle/phi/kernels/cpu/generate_proposals_kernel.cc | 4 +++- paddle/phi/kernels/gpu/generate_proposals_kernel.cu | 4 +++- paddle/phi/kernels/xpu/generate_proposals_kernel.cc | 4 +++- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 47b2c2a2b47..a2fcb4a2282 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -61,10 +61,8 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "eigh", "ftt_c2r", "ftt_r2c", - "generate_proposals", "graph_sample_neighbors", "group_norm", - "instance_norm", "lamb", "layer_norm", "layer_norm_grad", diff --git a/paddle/phi/kernels/cpu/generate_proposals_kernel.cc b/paddle/phi/kernels/cpu/generate_proposals_kernel.cc index 4a9569c045c..1b17afb6df7 100644 --- a/paddle/phi/kernels/cpu/generate_proposals_kernel.cc +++ b/paddle/phi/kernels/cpu/generate_proposals_kernel.cc @@ -389,4 +389,6 @@ PD_REGISTER_KERNEL(generate_proposals, ALL_LAYOUT, phi::GenerateProposalsKernel, float, - double) {} + double) { + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/gpu/generate_proposals_kernel.cu b/paddle/phi/kernels/gpu/generate_proposals_kernel.cu index 40df74756c0..38e0e27d99f 100644 --- a/paddle/phi/kernels/gpu/generate_proposals_kernel.cu +++ b/paddle/phi/kernels/gpu/generate_proposals_kernel.cu @@ -580,4 +580,6 @@ void GenerateProposalsKernel(const Context &ctx, } // namespace phi PD_REGISTER_KERNEL( - generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {} + generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) { + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/xpu/generate_proposals_kernel.cc b/paddle/phi/kernels/xpu/generate_proposals_kernel.cc index c7c202dec2c..367ebfde95a 100644 --- a/paddle/phi/kernels/xpu/generate_proposals_kernel.cc +++ b/paddle/phi/kernels/xpu/generate_proposals_kernel.cc @@ -409,4 +409,6 @@ void GenerateProposalsKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {} + generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) { + kernel->OutputAt(2).SetDataType(phi::DataType::INT32); +} -- GitLab