From 37bd7e780997b1bc0c1119c4fae38dd91db32d95 Mon Sep 17 00:00:00 2001 From: hjyp <53164956+Tomoko-hjf@users.noreply.github.com> Date: Wed, 29 Mar 2023 17:50:03 +0800 Subject: [PATCH] Add output defines for graph_sample_neighbors and group_norm (#51503) * regist output type for GraphSampleNeighbors and GroupNorm * Update return type * fix return type * update * fix detail --- .../new_executor/interpreter/interpreter_util.cc | 2 -- paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc | 4 +++- paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu | 4 +++- paddle/phi/kernels/gpu/group_norm_kernel.cu | 8 +++++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 54b8c93e92b..8ba9e7a70e5 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -56,8 +56,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "any_raw", "eig_grad", "eigh", - "graph_sample_neighbors", - "group_norm", "lamb", "layer_norm", "layer_norm_grad", diff --git a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc index f8fefa3450c..996951b9687 100644 --- a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc @@ -226,4 +226,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors, ALL_LAYOUT, phi::GraphSampleNeighborsKernel, int, - int64_t) {} + int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu index 3ea1dbc8e19..c01a8ea9d2e 100644 --- a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -483,4 +483,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors, ALL_LAYOUT, phi::GraphSampleNeighborsKernel, int, - int64_t) {} + int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu index c23bfd3aa72..ef39abd9394 100644 --- a/paddle/phi/kernels/gpu/group_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -347,4 +347,10 @@ PD_REGISTER_KERNEL(group_norm, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::BFLOAT16 || + kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + } +} -- GitLab