未验证 提交 87c5f23b 编写于 作者: J junxiu777 提交者: GitHub

add register of kthvalue (#51534)

* add register of KthvalueKernel

add register of KthvalueKernel

* Update kthvalue_kernel.cc

* Update kthvalue_kernel.cu
上级 4e9e23cb
......@@ -71,7 +71,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"group_norm",
"histogram",
"instance_norm",
"kthvalue",
"lamb",
"layer_norm",
"layer_norm_grad",
......
......@@ -178,4 +178,6 @@ PD_REGISTER_KERNEL(kthvalue,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -263,4 +263,6 @@ PD_REGISTER_KERNEL(kthvalue,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册