From 648ec7959a30eb55008e15fb00acb8812b62ef2d Mon Sep 17 00:00:00 2001 From: PuQing Date: Fri, 24 Mar 2023 11:29:06 +0800 Subject: [PATCH] [PHI]fix momentum dtype infer (#51353) * fix momentum dtype infer * fix momentum datatype * fix on cpu * add momentum --- paddle/phi/infermeta/multiary.cc | 7 ++++++- paddle/phi/kernels/gpu/momentum_kernel.cu | 10 ++++++++-- paddle/phi/kernels/xpu/momentum_kernel.cc | 5 ++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 08756a91c82..5cd480f425a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2172,10 +2172,15 @@ void MomentumInferMeta(const MetaTensor& param, auto param_dim = param.dims(); param_out->set_dims(param_dim); + auto MPType = (param.dtype() == phi::DataType::FLOAT16 || + param.dtype() == phi::DataType::BFLOAT16) + ? phi::DataType::FLOAT32 + : param.dtype(); velocity_out->set_dims(param_dim); - + velocity_out->set_dtype(MPType); if (master_param_out) { master_param_out->set_dims(param_dim); + master_param_out->set_dtype(MPType); } } diff --git a/paddle/phi/kernels/gpu/momentum_kernel.cu b/paddle/phi/kernels/gpu/momentum_kernel.cu index 5a4f5d33e61..6d2b51dff64 100644 --- a/paddle/phi/kernels/gpu/momentum_kernel.cu +++ b/paddle/phi/kernels/gpu/momentum_kernel.cu @@ -24,7 +24,10 @@ PD_REGISTER_KERNEL(momentum, phi::MomentumDenseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, GPU, @@ -32,4 +35,7 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad, phi::MomentumSparseKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/xpu/momentum_kernel.cc b/paddle/phi/kernels/xpu/momentum_kernel.cc index ad9cb2e6ef8..207bfef37f9 100644 --- a/paddle/phi/kernels/xpu/momentum_kernel.cc +++ b/paddle/phi/kernels/xpu/momentum_kernel.cc @@ -69,4 +69,7 @@ PD_REGISTER_KERNEL(momentum, ALL_LAYOUT, phi::MomentumDenseKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); +} -- GitLab