未验证 提交 a82911a5 编写于 作者: P PuQing 提交者: GitHub

[PHI] Add nanmedian output defs (#51358)

* add nanmedian output defs

* remove the multiclass_nms3 momentum
上级 ca7394cd
...@@ -65,9 +65,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -65,9 +65,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"less_equal", "less_equal",
"less_than", "less_than",
"merged_adam", "merged_adam",
"momentum",
"multiclass_nms3",
"nanmedian",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"unique", "unique",
"unique_consecutive_flattened_tensor", "unique_consecutive_flattened_tensor",
......
...@@ -207,4 +207,6 @@ PD_REGISTER_KERNEL(nanmedian, ...@@ -207,4 +207,6 @@ PD_REGISTER_KERNEL(nanmedian,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
...@@ -286,4 +286,6 @@ PD_REGISTER_KERNEL(nanmedian, ...@@ -286,4 +286,6 @@ PD_REGISTER_KERNEL(nanmedian,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册