未验证 提交 615fc429 编写于 作者: R Ryan 提交者: GitHub

[phi] add register of numel/svd (#51356)

* add numel INT64 register

* del numl

* add svd UNDEFINED register

* remove svd register
上级 f951832d
...@@ -89,13 +89,11 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -89,13 +89,11 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"multiclass_nms3", "multiclass_nms3",
"multinomial", "multinomial",
"nanmedian", "nanmedian",
"numl",
"rnn", "rnn",
"search_sort", "search_sort",
"select", "select",
"send_recv", "send_recv",
"send_ue_recv", "send_ue_recv",
"svd",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"unique", "unique",
"unique_consecutive_flattened_tensor", "unique_consecutive_flattened_tensor",
......
...@@ -30,4 +30,6 @@ PD_REGISTER_KERNEL(numel, ...@@ -30,4 +30,6 @@ PD_REGISTER_KERNEL(numel,
phi::dtype::bfloat16, phi::dtype::bfloat16,
float, float,
double, double,
bool) {} bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
...@@ -29,4 +29,6 @@ PD_REGISTER_KERNEL(numel, ...@@ -29,4 +29,6 @@ PD_REGISTER_KERNEL(numel,
phi::dtype::bfloat16, phi::dtype::bfloat16,
float, float,
double, double,
bool) {} bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册