未验证 提交 383a3f8c 编写于 作者: Z Zhenghai Zhang 提交者: GitHub

Add output defs for mode kernel (#51363)

* Add output defs for mode kernel

* fix bug
上级 d3ebf1e6
...@@ -74,7 +74,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -74,7 +74,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"less_equal", "less_equal",
"less_than", "less_than",
"merged_adam", "merged_adam",
"mode",
"momentum", "momentum",
"multiclass_nms3", "multiclass_nms3",
"multinomial", "multinomial",
......
...@@ -132,4 +132,6 @@ void ModeKernel(const Context& dev_ctx, ...@@ -132,4 +132,6 @@ void ModeKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
mode, CPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {} mode, CPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
...@@ -130,4 +130,6 @@ void ModeKernel(const Context& dev_ctx, ...@@ -130,4 +130,6 @@ void ModeKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {} mode, GPU, ALL_LAYOUT, phi::ModeKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册