diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 08756a91c82c2a07d40a4569e79ec22e8ca9f632..5cd480f425a7266c5340b36ef4f02a52fdadfa4c 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2172,10 +2172,15 @@ void MomentumInferMeta(const MetaTensor& param, auto param_dim = param.dims(); param_out->set_dims(param_dim); + auto MPType = (param.dtype() == phi::DataType::FLOAT16 || + param.dtype() == phi::DataType::BFLOAT16) + ? phi::DataType::FLOAT32 + : param.dtype(); velocity_out->set_dims(param_dim); - + velocity_out->set_dtype(MPType); if (master_param_out) { master_param_out->set_dims(param_dim); + master_param_out->set_dtype(MPType); } } diff --git a/paddle/phi/kernels/gpu/momentum_kernel.cu b/paddle/phi/kernels/gpu/momentum_kernel.cu index 5a4f5d33e6165370afb67960b8eb61200f034229..6d2b51dff64cb00bb968513c03be0cd989bda8cc 100644 --- a/paddle/phi/kernels/gpu/momentum_kernel.cu +++ b/paddle/phi/kernels/gpu/momentum_kernel.cu @@ -24,7 +24,10 @@ PD_REGISTER_KERNEL(momentum, phi::MomentumDenseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, GPU, @@ -32,4 +35,7 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, phi::MomentumSparseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/xpu/momentum_kernel.cc b/paddle/phi/kernels/xpu/momentum_kernel.cc index ad9cb2e6ef86ef86d09636554bacdd9801b58b30..207bfef37f947ae4ae3bb93bad52fe831de840d9 100644 --- a/paddle/phi/kernels/xpu/momentum_kernel.cc +++ b/paddle/phi/kernels/xpu/momentum_kernel.cc @@ -69,4 +69,7 @@ PD_REGISTER_KERNEL(momentum, ALL_LAYOUT, phi::MomentumDenseKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +}