diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 54b8c93e92b1d9020648f4695c26c598304cbad0..8ba9e7a70e590cd0b1715e1ebc41b52b3e0ebe76 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -56,8 +56,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "any_raw", "eig_grad", "eigh", - "graph_sample_neighbors", - "group_norm", "lamb", "layer_norm", "layer_norm_grad", diff --git a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc index f8fefa3450cea3147ae03a3b57f365d2886bdda7..996951b968730201650e6028d282b60134ec3df1 100644 --- a/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_sample_neighbors_kernel.cc @@ -226,4 +226,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors, ALL_LAYOUT, phi::GraphSampleNeighborsKernel, int, - int64_t) {} + int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu index 3ea1dbc8e19c20bfb485ab8c519ba6b5134d3cb1..c01a8ea9d2e016410b4b250c883bedd331d57870 100644 --- a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -483,4 +483,6 @@ PD_REGISTER_KERNEL(graph_sample_neighbors, ALL_LAYOUT, phi::GraphSampleNeighborsKernel, int, - int64_t) {} + int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu index c23bfd3aa72d9d292eee3399739d6fc23fb1643d..ef39abd9394102d8d68377c7761fd2070518f07e 100644 --- a/paddle/phi/kernels/gpu/group_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -347,4 +347,10 @@ PD_REGISTER_KERNEL(group_norm, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::BFLOAT16 || + kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + } +}