未验证 提交 939b58b2 编写于 作者: S Sanbu 提交者: GitHub

Add output defs for generate_proposals,instance_norm kernel (#51576)

* Add output defs for generate_proposals,instance_norm kernel

* fix
上级 d021095e
......@@ -61,10 +61,8 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"eigh",
"ftt_c2r",
"ftt_r2c",
"generate_proposals",
"graph_sample_neighbors",
"group_norm",
"instance_norm",
"lamb",
"layer_norm",
"layer_norm_grad",
......
......@@ -389,4 +389,6 @@ PD_REGISTER_KERNEL(generate_proposals,
ALL_LAYOUT,
phi::GenerateProposalsKernel,
float,
double) {}
double) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -580,4 +580,6 @@ void GenerateProposalsKernel(const Context &ctx,
} // namespace phi
PD_REGISTER_KERNEL(
generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {}
generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -409,4 +409,6 @@ void GenerateProposalsKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {}
generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册