未验证 提交 0661e8f1 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

Add output defs for edit_distance kernel (#51324)

* add output defs for edit_distance kernel

* change seqnum as output0
上级 7c9ccf5f
......@@ -65,7 +65,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"complex",
"conv3d_coo",
"distribute_fpn_proposals",
"edit_distance",
"eig",
"eig_grad",
"eigh",
......
......@@ -121,4 +121,6 @@ void EditDistanceKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
edit_distance, CPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {}
edit_distance, CPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -184,4 +184,6 @@ void EditDistanceKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
edit_distance, GPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {}
edit_distance, GPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册