未验证 提交 939b58b2 编写于 作者: S Sanbu 提交者: GitHub

Add output defs for generate_proposals,instance_norm kernel (#51576)

* Add output defs for generate_proposals,instance_norm kernel

* fix
上级 d021095e
...@@ -61,10 +61,8 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -61,10 +61,8 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"eigh", "eigh",
"ftt_c2r", "ftt_c2r",
"ftt_r2c", "ftt_r2c",
"generate_proposals",
"graph_sample_neighbors", "graph_sample_neighbors",
"group_norm", "group_norm",
"instance_norm",
"lamb", "lamb",
"layer_norm", "layer_norm",
"layer_norm_grad", "layer_norm_grad",
......
...@@ -389,4 +389,6 @@ PD_REGISTER_KERNEL(generate_proposals, ...@@ -389,4 +389,6 @@ PD_REGISTER_KERNEL(generate_proposals,
ALL_LAYOUT, ALL_LAYOUT,
phi::GenerateProposalsKernel, phi::GenerateProposalsKernel,
float, float,
double) {} double) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -580,4 +580,6 @@ void GenerateProposalsKernel(const Context &ctx, ...@@ -580,4 +580,6 @@ void GenerateProposalsKernel(const Context &ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {} generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -409,4 +409,6 @@ void GenerateProposalsKernel(const Context& dev_ctx, ...@@ -409,4 +409,6 @@ void GenerateProposalsKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {} generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册