未验证 提交 262358e8 编写于 作者: A Ainavo 提交者: GitHub

add output defs for nonzero and nms (#51325)

上级 35d31e9a
...@@ -94,8 +94,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -94,8 +94,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"multiclass_nms3", "multiclass_nms3",
"multinomial", "multinomial",
"nanmedian", "nanmedian",
"nms",
"nonzero",
"numl", "numl",
"qr", "qr",
"qr_grad", "qr_grad",
......
...@@ -94,4 +94,6 @@ void NMSKernel(const Context& dev_ctx, ...@@ -94,4 +94,6 @@ void NMSKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} PD_REGISTER_KERNEL(nms, CPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
...@@ -92,4 +92,6 @@ PD_REGISTER_KERNEL(nonzero, ...@@ -92,4 +92,6 @@ PD_REGISTER_KERNEL(nonzero,
int16_t, int16_t,
bool, bool,
float, float,
double) {} double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
...@@ -113,4 +113,6 @@ void NMSKernel(const Context& dev_ctx, ...@@ -113,4 +113,6 @@ void NMSKernel(const Context& dev_ctx,
dev_ctx.stream()); dev_ctx.stream());
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {} PD_REGISTER_KERNEL(nms, GPU, ALL_LAYOUT, phi::NMSKernel, float, double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
...@@ -83,4 +83,6 @@ PD_REGISTER_KERNEL(nonzero, ...@@ -83,4 +83,6 @@ PD_REGISTER_KERNEL(nonzero,
int16_t, int16_t,
bool, bool,
float, float,
double) {} double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
...@@ -69,4 +69,6 @@ void NonZeroKernel(const Context& dev_ctx, ...@@ -69,4 +69,6 @@ void NonZeroKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {} nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册