From 59c1ffa86a086f6dc8615f262d09c248f9c3264f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Thu, 9 Mar 2023 10:40:09 +0800 Subject: [PATCH] add register of VITERBI_DECODE (#51318) --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/kernels/cpu/viterbi_decode_kernel.cc | 4 +++- paddle/phi/kernels/gpu/viterbi_decode_kernel.cu | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index edec74d38c9..1f6cd392ec0 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -105,7 +105,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "unique", "unique_consecutive_flattened_tensor", "unique_raw", - "viterbi_decode", "viterbi_devode", "yolo_loss"}; diff --git a/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc b/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc index c520963b172..04c9f22ffe5 100644 --- a/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc +++ b/paddle/phi/kernels/cpu/viterbi_decode_kernel.cc @@ -318,4 +318,6 @@ void ViterbiDecodeKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {} + viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) { + kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu index b80e9253128..10923b11972 100644 --- a/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu +++ b/paddle/phi/kernels/gpu/viterbi_decode_kernel.cu @@ -397,4 +397,6 @@ void ViterbiDecodeKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {} + viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) { + kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64); +} -- GitLab