From c0f84b8f39186d9f15eaa0c40e3c43a15de6a9ef Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Thu, 9 Mar 2023 14:13:30 +0800 Subject: [PATCH] Add output defs for sgd kernel (#51332) * Add output defs for sgd kernel * add datatype infer for sgd * add infer logic --- .../new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/infermeta/multiary.cc | 9 +++++++++ paddle/phi/kernels/gpu/sgd_kernel.cu | 4 +++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 91802f04207..cf562043b70 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -96,7 +96,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "select", "send_recv", "send_ue_recv", - "sgd", "svd", "sync_batch_norm_grad", "unique", diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 9a4f233ce8f..e22df441fd8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2484,6 +2484,15 @@ void SgdInferMeta(const MetaTensor& param, param_out->set_dims(param.dims()); param_out->set_dtype(param.dtype()); + if (multi_precision) { + master_param_out->set_dims(master_param.dims()); + if (DataType::FLOAT16 == master_param.dtype() || + DataType::BFLOAT16 == master_param.dtype()) { + master_param_out->set_dtype(DataType::FLOAT32); + } else { + master_param_out->set_dtype(master_param.dtype()); + } + } } void SendUERecvInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/gpu/sgd_kernel.cu b/paddle/phi/kernels/gpu/sgd_kernel.cu index 73115a58fa9..d489ccb4cb2 100644 --- a/paddle/phi/kernels/gpu/sgd_kernel.cu +++ b/paddle/phi/kernels/gpu/sgd_kernel.cu @@ -187,7 +187,9 @@ PD_REGISTER_KERNEL(sgd, phi::SGDDenseKernel, phi::dtype::float16, float, - double) {} + double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); +} PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, GPU, -- GitLab