未验证 提交 4a4fcc2d 编写于 作者: P PPPPzhang 提交者: GitHub

add output defs for send_ue_recv kernel (#51522)

上级 11d7dae9
...@@ -89,7 +89,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -89,7 +89,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"rnn", "rnn",
"search_sort", "search_sort",
"select", "select",
"send_ue_recv",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"unique", "unique",
"unique_consecutive_flattened_tensor", "unique_consecutive_flattened_tensor",
......
...@@ -290,4 +290,6 @@ PD_REGISTER_KERNEL(send_ue_recv, ...@@ -290,4 +290,6 @@ PD_REGISTER_KERNEL(send_ue_recv,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}
...@@ -332,4 +332,6 @@ PD_REGISTER_KERNEL(send_ue_recv, ...@@ -332,4 +332,6 @@ PD_REGISTER_KERNEL(send_ue_recv,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册