未验证 提交 545e20f8 编写于 作者: S Sanbu 提交者: GitHub

[phi] Add output defs for argsort kernel (#51407)

* Add output defs for argsort kernel

* Update argsort_kernel.cc

* Update argsort_kernel.cu

* Update argsort_kernel.cc
上级 f5811a60
...@@ -54,7 +54,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -54,7 +54,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"adam", "adam",
"adamw", "adamw",
"any_raw", "any_raw",
"arg_sort",
"clip_by_norm", "clip_by_norm",
"eig_grad", "eig_grad",
"eigh", "eigh",
......
...@@ -154,4 +154,5 @@ void ArgsortKernel(const Context& dev_ctx, ...@@ -154,4 +154,5 @@ void ArgsortKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
argsort, CPU, ALL_LAYOUT, phi::ArgsortKernel, float, double, int, int64_t) { argsort, CPU, ALL_LAYOUT, phi::ArgsortKernel, float, double, int, int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -519,4 +519,6 @@ PD_REGISTER_KERNEL(argsort, ...@@ -519,4 +519,6 @@ PD_REGISTER_KERNEL(argsort,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
...@@ -248,4 +248,6 @@ PD_REGISTER_KERNEL(argsort, ...@@ -248,4 +248,6 @@ PD_REGISTER_KERNEL(argsort,
float, float,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册