未验证 提交 35d31e9a 编写于 作者: H Huang Jiyi 提交者: GitHub

Add output defs for some kernels (#51333)

上级 4050ca0e
...@@ -88,10 +88,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -88,10 +88,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"layer_norm_grad", "layer_norm_grad",
"less_equal", "less_equal",
"less_than", "less_than",
"lstsq",
"lu",
"matrix_nms",
"matrix_rank_tol",
"merged_adam", "merged_adam",
"mode", "mode",
"momentum", "momentum",
......
...@@ -301,4 +301,6 @@ void LstsqKernel(const Context& dev_ctx, ...@@ -301,4 +301,6 @@ void LstsqKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {} PD_REGISTER_KERNEL(lstsq, CPU, ALL_LAYOUT, phi::LstsqKernel, float, double) {
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -73,4 +73,7 @@ void LUKernel(const Context& dev_ctx, ...@@ -73,4 +73,7 @@ void LUKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {} PD_REGISTER_KERNEL(lu, CPU, ALL_LAYOUT, phi::LUKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -318,4 +318,7 @@ void MatrixNMSKernel(const Context& ctx, ...@@ -318,4 +318,7 @@ void MatrixNMSKernel(const Context& ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {} matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
...@@ -175,4 +175,5 @@ void MatrixRankTolKernel(const Context& dev_ctx, ...@@ -175,4 +175,5 @@ void MatrixRankTolKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
matrix_rank_tol, CPU, ALL_LAYOUT, phi::MatrixRankTolKernel, float, double) { matrix_rank_tol, CPU, ALL_LAYOUT, phi::MatrixRankTolKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
} }
...@@ -183,6 +183,9 @@ PD_REGISTER_KERNEL(lu, // cuda_only ...@@ -183,6 +183,9 @@ PD_REGISTER_KERNEL(lu, // cuda_only
ALL_LAYOUT, ALL_LAYOUT,
phi::LUKernel, phi::LUKernel,
float, float,
double) {} double) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
#endif // not PADDLE_WITH_HIP #endif // not PADDLE_WITH_HIP
...@@ -445,6 +445,8 @@ PD_REGISTER_KERNEL(matrix_rank_tol, // cuda_only ...@@ -445,6 +445,8 @@ PD_REGISTER_KERNEL(matrix_rank_tol, // cuda_only
ALL_LAYOUT, ALL_LAYOUT,
phi::MatrixRankTolKernel, phi::MatrixRankTolKernel,
float, float,
double) {} double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif // not PADDLE_WITH_HIP #endif // not PADDLE_WITH_HIP
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册