未验证 提交 35d31e9a 编写于 作者: H Huang Jiyi 提交者: GitHub

Add output defs for some kernels (#51333)

上级 4050ca0e
......@@ -88,10 +88,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"layer_norm_grad",
"less_equal",
"less_than",
"lstsq",
"lu",
"matrix_nms",
"matrix_rank_tol",
"merged_adam",
"mode",
"momentum",
......
......@@ -301,4 +301,6 @@ void LstsqKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {}
PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -73,4 +73,7 @@ void LUKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {}
PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -318,4 +318,7 @@ void MatrixNMSKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {}
matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -175,4 +175,5 @@ void MatrixRankTolKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
matrix_rank_tol, CPU, ALL_LAYOUT, phi::MatrixRankTolKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -183,6 +183,9 @@ PD_REGISTER_KERNEL(lu, // cuda_only
ALL_LAYOUT,
phi::LUKernel,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
#endif // not PADDLE_WITH_HIP
......@@ -445,6 +445,8 @@ PD_REGISTER_KERNEL(matrix_rank_tol, // cuda_only
ALL_LAYOUT,
phi::MatrixRankTolKernel,
float,
double) {}
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif // not PADDLE_WITH_HIP
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册