未验证 提交 37bd7e78 编写于 作者: H hjyp 提交者: GitHub

Add output defines for graph_sample_neighbors and group_norm (#51503)

* regist output type for GraphSampleNeighbors and GroupNorm

* Update return type

* fix return type

* update

* fix detail
上级 a5ca2672
...@@ -56,8 +56,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -56,8 +56,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"any_raw", "any_raw",
"eig_grad", "eig_grad",
"eigh", "eigh",
"graph_sample_neighbors",
"group_norm",
"lamb", "lamb",
"layer_norm", "layer_norm",
"layer_norm_grad", "layer_norm_grad",
......
...@@ -226,4 +226,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors, ...@@ -226,4 +226,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors,
ALL_LAYOUT, ALL_LAYOUT,
phi::GraphSampleNeighborsKernel, phi::GraphSampleNeighborsKernel,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}
...@@ -483,4 +483,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors, ...@@ -483,4 +483,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors,
ALL_LAYOUT, ALL_LAYOUT,
phi::GraphSampleNeighborsKernel, phi::GraphSampleNeighborsKernel,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}
...@@ -347,4 +347,10 @@ PD_REGISTER_KERNEL(group_norm, ...@@ -347,4 +347,10 @@ PD_REGISTER_KERNEL(group_norm,
float, float,
double, double,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) {} phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::BFLOAT16 ||
kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册