未验证 提交 2847980c 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

Add output defs for fused_adam kernel (#51323)

* add output defs for fused_adam kernel

* complete the oters defs for cpu and gpu

* remove register for param_out
上级 c0f84b8f
......@@ -70,7 +70,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"eigh",
"ftt_c2r",
"ftt_r2c",
"fused_adam",
"fused_matmul",
"generate_proposals",
"graph_sample_neighbors",
......
......@@ -158,4 +158,10 @@ void FusedAdamKernel(
} // namespace phi
PD_REGISTER_KERNEL(
fused_adam, CPU, ALL_LAYOUT, phi::FusedAdamKernel, float, double) {}
fused_adam, CPU, ALL_LAYOUT, phi::FusedAdamKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -498,4 +498,9 @@ PD_REGISTER_KERNEL(fused_adam,
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册