未验证 提交 09241d85 编写于 作者: S Sanbu 提交者: GitHub

Add output defs for conv3d_coo distribute_fpn_proposals kernel (#51516)

* Add output defs for conv3d_coo distribute_fpn_proposals kernel

* fix
上级 f9a4f007
...@@ -58,8 +58,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -58,8 +58,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"atan2", "atan2",
"clip_by_norm", "clip_by_norm",
"complex", "complex",
"conv3d_coo",
"distribute_fpn_proposals",
"eig", "eig",
"eig_grad", "eig_grad",
"eigh", "eigh",
......
...@@ -142,4 +142,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals, ...@@ -142,4 +142,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals,
ALL_LAYOUT, ALL_LAYOUT,
phi::DistributeFpnProposalsKernel, phi::DistributeFpnProposalsKernel,
float, float,
double) {} double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -266,4 +266,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals, ...@@ -266,4 +266,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals,
ALL_LAYOUT, ALL_LAYOUT,
phi::DistributeFpnProposalsKernel, phi::DistributeFpnProposalsKernel,
float, float,
double) {} double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -205,4 +205,7 @@ void Conv3dCooKernel(const Context& dev_ctx, ...@@ -205,4 +205,7 @@ void Conv3dCooKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
conv3d_coo, CPU, ALL_LAYOUT, phi::sparse::Conv3dCooKernel, float, double) { conv3d_coo, CPU, ALL_LAYOUT, phi::sparse::Conv3dCooKernel, float, double) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(paddle::DataType::INT32);
kernel->OutputAt(2).SetDataType(paddle::DataType::INT32);
} }
...@@ -295,4 +295,7 @@ PD_REGISTER_KERNEL(conv3d_coo, ...@@ -295,4 +295,7 @@ PD_REGISTER_KERNEL(conv3d_coo,
double, double,
phi::dtype::float16) { phi::dtype::float16) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(paddle::DataType::INT32);
kernel->OutputAt(2).SetDataType(paddle::DataType::INT32);
} }
...@@ -180,4 +180,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals, ...@@ -180,4 +180,7 @@ PD_REGISTER_KERNEL(distribute_fpn_proposals,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::DistributeFpnProposalsKernel, phi::DistributeFpnProposalsKernel,
float) {} float) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册