未验证 提交 383a3f8c 编写于 作者: Z Zhenghai Zhang 提交者: GitHub

Add output defs for mode kernel (#51363)

* Add output defs for mode kernel

* fix bug
上级 d3ebf1e6
......@@ -74,7 +74,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"less_equal",
"less_than",
"merged_adam",
"mode",
"momentum",
"multiclass_nms3",
"multinomial",
......
......@@ -132,4 +132,6 @@ void ModeKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
mode, CPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {}
mode, CPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -130,4 +130,6 @@ void ModeKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {}
mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册