Unverified · Commit 7541579a authored by Chenxiao Niu, committed by GitHub

[MLU] add masterparam support for mlu adamw. (#46804)

Parent 28ef0fff
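This commit drops the old `PADDLE_THROW(Unimplemented)` branch and adds real master-parameter support: with the `multi_precision` attribute set, the kernel casts `MasterParam` (fp32) down to `Param` (fp16) before the weight-decay and Adam update, then casts `ParamOut` (fp16) back up to `MasterParamOut` (fp32). As a minimal standalone sketch of why an fp32 master copy is kept at all (not Paddle code; the low-precision cast is emulated by mantissa rounding):

```cpp
#include <cmath>
#include <cstdio>

// Emulate a lossy low-precision cast by keeping ~8 mantissa bits
// (roughly bfloat16 resolution); purely illustrative.
static float ToLowPrecision(float x) {
  if (x == 0.0f) return 0.0f;
  int exp;
  float m = std::frexp(x, &exp);        // x = m * 2^exp, |m| in [0.5, 1)
  m = std::round(m * 256.0f) / 256.0f;  // keep 8 fractional bits of m
  return std::ldexp(m, exp);
}

int main() {
  const float lr = 1.0f, grad = 1e-4f;
  float master = 1.0f;                   // fp32 master weight
  float param = ToLowPrecision(master);  // low-precision working copy

  for (int step = 0; step < 100; ++step) {
    // Without a master copy: each tiny update is rounded away.
    param = ToLowPrecision(param - lr * grad);
    // With a master copy: the update accumulates in fp32 -- the
    // motivation for keeping MasterParam in the kernel below.
    master -= lr * grad;
  }
  std::printf("low-precision only: %f\n", param);                 // stuck at 1.0
  std::printf("fp32 master, cast down: %f\n", ToLowPrecision(master));  // ~0.992
  return 0;
}
```

Small per-step updates round to nothing in the low-precision copy but accumulate in the fp32 master; that lost-update failure mode is what master parameters exist to prevent.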
@@ -291,11 +291,38 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       skip_update = skip_update_vec[0];
     }
     bool with_decay = ctx.Attr<bool>("with_decay");
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
+    auto* master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
+    const auto* master_param = ctx.Input<LoDTensor>("MasterParam");
     VLOG(3) << "Skip update: " << skip_update << ", With decay: " << with_decay;
     if (!skip_update && with_decay) {
-      if (ctx.HasInput("MasterParam")) {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Master Param is not supported on MLU"));
+      auto* param = ctx.Input<LoDTensor>("Param");
+      MLUCnnlTensorDesc param_desc(*param);
+      if (multi_precision) {
+        VLOG(3) << "[adamw] multi_precision, cast masterparam to param.";
+        bool has_master =
+            ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
+        PADDLE_ENFORCE_EQ(
+            has_master,
+            true,
+            platform::errors::InvalidArgument(
+                "The Input(MasterParam) and Output(MasterParamOut) "
+                "should not be null when "
+                "the attr `multi_precision` is true"));
+        // cast masterparam (fp32) to param (fp16), then paramout (fp16) to
+        // masterparamout (fp32)
+        MLUCnnlTensorDesc master_param_desc(*master_param);
+        cnnlCastDataType_t cast_type = GetCastDataType(
+            framework::TransToProtoVarType(master_param->dtype()),
+            framework::TransToProtoVarType(param->dtype()));
+        MLUCnnl::Cast(ctx,
+                      cast_type,
+                      master_param_desc.get(),
+                      GetBasePtr(master_param),
+                      param_desc.get(),
+                      const_cast<void*>(GetBasePtr(param)));
       } else {
         const auto* param_var = ctx.InputVar("Param");
         PADDLE_ENFORCE_EQ(param_var->IsType<phi::DenseTensor>(),
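The descriptor-plus-`MLUCnnl::Cast` pattern above appears twice in this kernel (fp32→fp16 before the update, fp16→fp32 after it). A hypothetical helper — not part of this commit, composed only from calls already visible in the diff — could factor it out:

```cpp
// Hypothetical refactor (not in this commit): wrap the repeated
// dtype-cast pattern used for MasterParam->Param and ParamOut->MasterParamOut.
static void CastDenseTensor(const framework::ExecutionContext& ctx,
                            const phi::DenseTensor& src,
                            phi::DenseTensor* dst) {
  MLUCnnlTensorDesc src_desc(src);
  MLUCnnlTensorDesc dst_desc(*dst);
  // Pick the CNNL cast direction from the source/destination dtypes.
  cnnlCastDataType_t cast_type =
      GetCastDataType(framework::TransToProtoVarType(src.dtype()),
                      framework::TransToProtoVarType(dst->dtype()));
  MLUCnnl::Cast(ctx, cast_type, src_desc.get(), GetBasePtr(&src),
                dst_desc.get(), GetBasePtr(dst));
}
```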
@@ -305,13 +332,12 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
                               "but the received is %s",
                               ctx.InputNames("Param").front(),
                               framework::ToTypeName(param_var->Type())));
-        auto* param = ctx.Input<LoDTensor>("Param");
         auto* lr = ctx.Input<LoDTensor>("LearningRate");
         float coeff = ctx.Attr<float>("coeff");
         // update param with decay coeff: mul(-1 * lr, coeff * param) + param
         MLUCnnlTensorDesc lr_desc(*lr);
-        MLUCnnlTensorDesc param_desc(*param);
         MLUCnnlOpTensorDesc mul_op_desc(
             CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(), CNNL_NOT_PROPAGATE_NAN);
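The comment in this hunk spells out AdamW's decoupled weight decay in CNNL ops: `param + mul(-1 * lr, coeff * param)`, i.e. `param ← (1 − lr·coeff) · param`, applied directly to the parameter before the Adam step rather than folded into the gradient as in Adam with L2 regularization. The `param`/`param_desc` declarations deleted here are the ones hoisted above the `multi_precision` branch in the previous hunk.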
@@ -330,6 +356,23 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
       }
     }
     AdamMLUKernel<T>::Compute(ctx);
+    if (multi_precision) {
+      VLOG(3) << "[adamw] multi_precision, cast paramout to masterparamout.";
+      // cast paramout to masterparamout
+      master_param_out->mutable_data<float>(ctx.GetPlace());
+      cnnlCastDataType_t cast_type = GetCastDataType(
+          framework::TransToProtoVarType(param_out->dtype()),
+          framework::TransToProtoVarType(master_param_out->dtype()));
+      MLUCnnlTensorDesc param_out_desc(*param_out);
+      MLUCnnlTensorDesc master_param_out_desc(*master_param_out);
+      MLUCnnl::Cast(ctx,
+                    cast_type,
+                    param_out_desc.get(),
+                    GetBasePtr(param_out),
+                    master_param_out_desc.get(),
+                    GetBasePtr(master_param_out));
+    }
   }
 };
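With both casts in place, fp16 AdamW on MLU keeps an fp32 master copy of each parameter across steps. On the Python side this path would typically be driven by the optimizer's `multi_precision` option (e.g. `paddle.optimizer.AdamW(..., multi_precision=True)`), which is what populates the `MasterParam`/`MasterParamOut` slots the kernel checks; that wiring is an assumption from the attribute name, not shown in this diff.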