未验证 提交 09241d85 编写于 作者: S Sanbu 提交者: GitHub

Add output defs for conv3d_coo distribute_fpn_proposals kernel (#51516)

* Add output defs for conv3d_coo distribute_fpn_proposals kernel

* fix
上级 f9a4f007
......@@ -58,8 +58,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"atan2",
"clip_by_norm",
"complex",
"conv3d_coo",
"distribute_fpn_proposals",
"eig",
"eig_grad",
"eigh",
......
......@@ -142,4 +142,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals,
ALL_LAYOUT,
phi::DistributeFpnProposalsKernel,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -266,4 +266,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals,
ALL_LAYOUT,
phi::DistributeFpnProposalsKernel,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -205,4 +205,7 @@ void Conv3dCooKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
conv3d_coo, CPU, ALL_LAYOUT, phi::sparse::Conv3dCooKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(paddle::DataType::INT32);
kernel->OutputAt(2).SetDataType(paddle::DataType::INT32);
}
......@@ -295,4 +295,7 @@ PD_REGISTER_KERNEL(conv3d_coo,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(paddle::DataType::INT32);
kernel->OutputAt(2).SetDataType(paddle::DataType::INT32);
}
......@@ -180,4 +180,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals,
XPU,
ALL_LAYOUT,
phi::DistributeFpnProposalsKernel,
float) {}
float) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册