未验证 提交 b5232bf4 编写于 作者: R Ruibiao Chen 提交者: GitHub

Add output defs for topk kernel (#51233)

上级 883b6aba
...@@ -118,7 +118,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -118,7 +118,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"sgd", "sgd",
"svd", "svd",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"top_k",
"unique", "unique",
"unique_consecutive_flattened_tensor", "unique_consecutive_flattened_tensor",
"unique_raw", "unique_raw",
......
...@@ -241,4 +241,6 @@ void TopkKernel(const Context& dev_ctx, ...@@ -241,4 +241,6 @@ void TopkKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {} topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
}
...@@ -348,4 +348,6 @@ PD_REGISTER_KERNEL(topk, ...@@ -348,4 +348,6 @@ PD_REGISTER_KERNEL(topk,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
}
...@@ -187,4 +187,6 @@ void TopkKernel(const Context& dev_ctx, ...@@ -187,4 +187,6 @@ void TopkKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) {} topk, XPU, ALL_LAYOUT, phi::TopkKernel, float, phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册