未验证 提交 262358e8 编写于 作者: A Ainavo 提交者: GitHub

add output defs for nonzero and nms (#51325)

上级 35d31e9a
......@@ -94,8 +94,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"multiclass_nms3",
"multinomial",
"nanmedian",
"nms",
"nonzero",
"numl",
"qr",
"qr_grad",
......
......@@ -94,4 +94,6 @@ void NMSKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {}
PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
......@@ -92,4 +92,6 @@ PD_REGISTER_KERNEL(nonzero,
int16_t,
bool,
float,
double) {}
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
......@@ -113,4 +113,6 @@ void NMSKernel(const Context& dev_ctx,
dev_ctx.stream());
}
} // namespace phi
PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {}
PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
......@@ -83,4 +83,6 @@ PD_REGISTER_KERNEL(nonzero,
int16_t,
bool,
float,
double) {}
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
......@@ -69,4 +69,6 @@ void NonZeroKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {}
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册