未验证 提交 59c1ffa8 编写于 作者: 张春乔 提交者: GitHub

add register of VITERBI_DECODE (#51318)

上级 9ffdb2b7
......@@ -105,7 +105,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"unique",
"unique_consecutive_flattened_tensor",
"unique_raw",
"viterbi_decode",
"viterbi_devode",
"yolo_loss"};
......
......@@ -318,4 +318,6 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {}
viterbi_decode, CPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
}
......@@ -397,4 +397,6 @@ void ViterbiDecodeKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {}
viterbi_decode, GPU, ALL_LAYOUT, phi::ViterbiDecodeKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册