From 7541579ad6b79fd0c6b6f0ba4d30957a6b021767 Mon Sep 17 00:00:00 2001
From: Chenxiao Niu
Date: Tue, 11 Oct 2022 19:03:39 +0800
Subject: [PATCH] [MLU] add masterparam support for mlu adamw. (#46804)

---
 .../fluid/operators/optimizers/adam_op_mlu.cc | 53 +++++++++++++++++--
 1 file changed, 48 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/adam_op_mlu.cc b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
index 88e94b32cc..80743606c7 100644
--- a/paddle/fluid/operators/optimizers/adam_op_mlu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_mlu.cc
@@ -291,11 +291,38 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       skip_update = skip_update_vec[0];
     }
     bool with_decay = ctx.Attr<bool>("with_decay");
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
+    auto* master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
+    const auto* master_param = ctx.Input<LoDTensor>("MasterParam");
+
     VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay;
     if (!skip_update && with_decay) {
-      if (ctx.HasInput("MasterParam")) {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Master Param is not supported on MLU"));
+      auto* param = ctx.Input<LoDTensor>("Param");
+      MLUCnnlTensorDesc param_desc(*param);
+      if (multi_precision) {
+        VLOG(3) << "[adamw] multi_precision, cast masterparam to param.";
+        bool has_master =
+            ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
+        PADDLE_ENFORCE_EQ(
+            has_master,
+            true,
+            platform::errors::InvalidArgument(
+                "The Input(MasterParam) and Output(MasterParamOut) "
+                "should not be null when "
+                "the attr `multi_precision` is true"));
+        // cast masterparam (fp32) to param (fp16), then paramout (fp16) to
+        // masterparamout (fp32)
+        MLUCnnlTensorDesc master_param_desc(*master_param);
+        cnnlCastDataType_t cast_type = GetCastDataType(
+            framework::TransToProtoVarType(master_param->dtype()),
+            framework::TransToProtoVarType(param->dtype()));
+        MLUCnnl::Cast(ctx,
+                      cast_type,
+                      master_param_desc.get(),
+                      GetBasePtr(master_param),
+                      param_desc.get(),
+                      const_cast<void*>(GetBasePtr(param)));
       } else {
         const auto* param_var = ctx.InputVar("Param");
         PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(),
@@ -305,13 +332,12 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
                               "but the received is %s",
                               ctx.InputNames("Param").front(),
                               framework::ToTypeName(param_var->Type())));
-        auto* param = ctx.Input<LoDTensor>("Param");
+
         auto* lr = ctx.Input<LoDTensor>("LearningRate");
         float coeff = ctx.Attr<float>("coeff");
 
         // update param with decay coeff: mul(-1 * lr, coeff * param) + param
         MLUCnnlTensorDesc lr_desc(*lr);
-        MLUCnnlTensorDesc param_desc(*param);
         MLUCnnlOpTensorDesc mul_op_desc(
             CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
 
@@ -330,6 +356,23 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       }
     }
     AdamMLUKernel<T>::Compute(ctx);
+    if (multi_precision) {
+      VLOG(3) << "[adamw] multi_precision, cast paramout to masterparamout.";
+      // cast paramout to masterparamout
+      master_param_out->mutable_data(ctx.GetPlace());
+      cnnlCastDataType_t cast_type = GetCastDataType(
+          framework::TransToProtoVarType(param_out->dtype()),
+          framework::TransToProtoVarType(master_param_out->dtype()));
+      MLUCnnlTensorDesc param_out_desc(*param_out);
+      MLUCnnlTensorDesc master_param_out_desc(*master_param_out);
+
+      MLUCnnl::Cast(ctx,
+                    cast_type,
+                    param_out_desc.get(),
+                    GetBasePtr(param_out),
+                    master_param_out_desc.get(),
+                    GetBasePtr(master_param_out));
+    }
   }
 };
 
-- 
GitLab
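
Editor's note, not part of the upstream patch: a minimal Python sketch of how the multi_precision path added above would typically be exercised. It assumes a PaddlePaddle build with Cambricon MLU support (so paddle.set_device("mlu") is valid) and a Paddle 2.4-era AMP API; exact device strings and AMP setup may differ between builds.

    # Hypothetical end-to-end sketch: fp16 parameters with fp32 master weights.
    import paddle

    paddle.set_device("mlu")  # assumption: MLU-enabled PaddlePaddle build

    model = paddle.nn.Linear(16, 4)
    opt = paddle.optimizer.AdamW(
        learning_rate=1e-3,
        parameters=model.parameters(),
        weight_decay=0.01,
        multi_precision=True,  # feeds MasterParam/MasterParamOut to the kernel
    )
    # Level O2 casts the parameters to float16; the optimizer keeps fp32 master
    # copies, which the kernel casts to/from around the weight-decay update.
    model, opt = paddle.amp.decorate(models=model, optimizers=opt, level="O2")

    with paddle.amp.auto_cast(level="O2"):
        loss = model(paddle.randn([8, 16])).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()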