From 0661e8f15372e486d92a7ae96029db8a3d2df7fa Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Thu, 9 Mar 2023 14:07:12 +0800 Subject: [PATCH] Add output defs for edit_distance kernel (#51324) * add output defs for edit_distance kernel * change seqnum as output0 --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/kernels/cpu/edit_distance_kernel.cc | 4 +++- paddle/phi/kernels/gpu/edit_distance_kernel.cu | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 413e0b74c09..91802f04207 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -65,7 +65,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "complex", "conv3d_coo", "distribute_fpn_proposals", - "edit_distance", "eig", "eig_grad", "eigh", diff --git a/paddle/phi/kernels/cpu/edit_distance_kernel.cc b/paddle/phi/kernels/cpu/edit_distance_kernel.cc index 190bc3fa552..7e77cc719b8 100644 --- a/paddle/phi/kernels/cpu/edit_distance_kernel.cc +++ b/paddle/phi/kernels/cpu/edit_distance_kernel.cc @@ -121,4 +121,6 @@ void EditDistanceKernel(const Context& ctx, } // namespace phi PD_REGISTER_KERNEL( - edit_distance, CPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {} + edit_distance, CPU, ALL_LAYOUT, phi::EditDistanceKernel, float) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/edit_distance_kernel.cu b/paddle/phi/kernels/gpu/edit_distance_kernel.cu index cb5b096ba3f..6ff1706f7a9 100644 --- a/paddle/phi/kernels/gpu/edit_distance_kernel.cu +++ b/paddle/phi/kernels/gpu/edit_distance_kernel.cu @@ -184,4 +184,6 @@ void EditDistanceKernel(const Context& ctx, } // namespace phi PD_REGISTER_KERNEL( - edit_distance, GPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {} + edit_distance, GPU, ALL_LAYOUT, phi::EditDistanceKernel, float) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT64); +} -- GitLab