diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index cf562043b70afa193183d77d57e03fb0c7e37630..27fdcefc730fbe1636e5dd6d8255d936cee0bd81 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -70,7 +70,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "eigh", "ftt_c2r", "ftt_r2c", - "fused_adam", "fused_matmul", "generate_proposals", "graph_sample_neighbors", diff --git a/paddle/phi/kernels/cpu/fused_adam_kernel.cc b/paddle/phi/kernels/cpu/fused_adam_kernel.cc index 9d71f2469423e5fecc732cdfefffb740ce3bc268..c6434be8077d9abfe881b2f34b4a5143bf7e4d7c 100644 --- a/paddle/phi/kernels/cpu/fused_adam_kernel.cc +++ b/paddle/phi/kernels/cpu/fused_adam_kernel.cc @@ -158,4 +158,10 @@ void FusedAdamKernel( } // namespace phi PD_REGISTER_KERNEL( - fused_adam, CPU, ALL_LAYOUT, phi::FusedAdamKernel, float, double) {} + fused_adam, CPU, ALL_LAYOUT, phi::FusedAdamKernel, float, double) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index 644e2085039c5506f448e594535a9f2e6ef7af55..533ef6fd1509c4cd0c87216691ab54f0dd98bd94 100644 --- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -498,4 +498,9 @@ PD_REGISTER_KERNEL(fused_adam, kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED); }