/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/optimizers/adamw_op.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { template __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, MT beta1_pow_, MT beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2, MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out, const MT* master_param, MT* master_param_out, int ndim) { MT lr = *lr_ * lr_ratio; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; int id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? master_param[id] : static_cast(param[id]); MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; } } } template __global__ void AdamWKernelMEM( MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow_, const MT* beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2, MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out, const MT* master_param, MT* master_param_out, int ndim) { MT lr = *lr_ * lr_ratio; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; int id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? master_param[id] : static_cast(param[id]); MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; } } } template __global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_, const T* beta2_pow_, T* beta1_pow_out, T* beta2_pow_out) { *beta1_pow_out = beta1 * beta1_pow_[0]; *beta2_pow_out = beta2 * beta2_pow_[0]; } template __global__ void SparseAdamWCUDAKernelREG( MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT beta1_pow, const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_, T* param_out_, const MT* master_param, MT* master_param_out, const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) { int id = blockIdx.x * blockDim.x + threadIdx.x; MT lr = *lr_ * lr_ratio; for (; id < ndim; id += blockDim.x * gridDim.x) { auto row_idx = phi::funcs::BinarySearch(rows_, row_count, id / row_numel); if (lazy_mode && row_idx < 0) { return; } else { MT mom1 = static_cast(mom1_[id]); MT mom2 = static_cast(mom2_[id]); MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 ? static_cast(grad_[row_idx * row_numel + id % row_numel]) : static_cast(0); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); // Write back to global memory mom1_out_[id] = mom1; mom2_out_[id] = mom2; param_out_[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; } } } } template class AdamWOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { const auto* param_var = ctx.InputVar("Param"); PADDLE_ENFORCE_EQ(param_var->IsType(), true, platform::errors::InvalidArgument( "The Var(%s)'s type should be LoDTensor, " "but the received is %s", ctx.InputNames("Param").front(), framework::ToTypeName(param_var->Type()))); using paddle::framework::LoDTensor; using MPDType = typename details::MPTypeTrait::Type; int64_t min_row_size_to_use_multithread = ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; MPDType coeff = static_cast(ctx.Attr("coeff")); MPDType lr_ratio = static_cast(ctx.Attr("lr_ratio")); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); auto* mom1 = ctx.Input("Moment1"); auto* mom2 = ctx.Input("Moment2"); auto* lr = ctx.Input("LearningRate"); auto* beta1_pow = ctx.Input("Beta1Pow"); auto* beta2_pow = ctx.Input("Beta2Pow"); auto* param_out = ctx.Output("ParamOut"); auto* mom1_out = ctx.Output("Moment1Out"); auto* mom2_out = ctx.Output("Moment2Out"); auto* beta1_pow_out = ctx.Output("Beta1PowOut"); auto* beta2_pow_out = ctx.Output("Beta2PowOut"); bool skip_update = false; if (ctx.HasInput("SkipUpdate")) { auto* skip_update_tensor = ctx.Input("SkipUpdate"); PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(SkipUpdate) size must be 1, but get %d", skip_update_tensor->numel())); std::vector skip_update_vec; paddle::framework::TensorToVector(*skip_update_tensor, ctx.device_context(), &skip_update_vec); skip_update = skip_update_vec[0]; } // skip_update=true, just copy input to output, and TensorCopy will call // mutable_data if (skip_update) { VLOG(4) << "Adamw skip update"; framework::TensorCopy( *param, ctx.GetPlace(), ctx.template device_context(), param_out); framework::TensorCopy( *mom1, ctx.GetPlace(), ctx.template device_context(), mom1_out); framework::TensorCopy( *mom2, ctx.GetPlace(), ctx.template device_context(), mom2_out); framework::TensorCopy( *beta1_pow, ctx.GetPlace(), ctx.template device_context(), beta1_pow_out); framework::TensorCopy( *beta2_pow, ctx.GetPlace(), ctx.template device_context(), beta2_pow_out); return; } // if with_decay = false, coeff = 0 bool with_decay = ctx.Attr("with_decay"); if (!with_decay) { coeff = static_cast(0.0); } MPDType beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { auto* beta1_tensor = ctx.Input("Beta1Tensor"); PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(Beta1Tensor) size must be 1, but get %d", beta1_tensor->numel())); beta1 = static_cast(GetAttrFromTensor(beta1_tensor)); } MPDType beta2 = static_cast(ctx.Attr("beta2")); if (ctx.HasInput("Beta2Tensor")) { auto* beta2_tensor = ctx.Input("Beta2Tensor"); PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(Beta2Tensor) size must be 1, but get %d", beta2_tensor->numel())); beta2 = static_cast(GetAttrFromTensor(beta2_tensor)); } MPDType epsilon = static_cast(ctx.Attr("epsilon")); if (ctx.HasInput("EpsilonTensor")) { auto* epsilon_tensor = ctx.Input("EpsilonTensor"); PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(EpsilonTensor) size must be 1, but get %d", epsilon_tensor->numel())); epsilon = static_cast(GetAttrFromTensor(epsilon_tensor)); } VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() << "beta2_pow.numel() : " << beta2_pow->numel(); VLOG(3) << "param.numel(): " << param->numel(); PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1, platform::errors::InvalidArgument( "beta1 pow output size should be 1, but received " "value is:%d.", beta1_pow_out->numel())); PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1, platform::errors::InvalidArgument( "beta2 pow output size should be 1, but received " "value is:%d.", beta2_pow_out->numel())); const bool multi_precision = ctx.Attr("multi_precision"); const LoDTensor* master_param = nullptr; LoDTensor* master_param_out = nullptr; if (multi_precision) { bool has_master = ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); PADDLE_ENFORCE_EQ(has_master, true, platform::errors::InvalidArgument( "The Input(MasterParam) and Output(MasterParamOut) " "should not be null when " "the attr `multi_precision` is true")); master_param = ctx.Input("MasterParam"); master_param_out = ctx.Output("MasterParamOut"); } const MPDType* master_in_data = multi_precision ? master_param->data() : nullptr; MPDType* master_out_data = multi_precision ? master_param_out->mutable_data(ctx.GetPlace()) : nullptr; auto& dev_ctx = ctx.template device_context(); if (grad_var->IsType()) { auto* grad = ctx.Input("Grad"); // update param and moment int threads = 512; int blocks = (param->numel() + threads - 1) / threads; if (beta1_pow->place() == platform::CPUPlace() && beta2_pow->place() == platform::CPUPlace()) { // Compute with betapow in REG AdamWKernelREG<<>>( beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data(), *beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), lr->data(), grad->data(), param->data(), param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, param->numel()); if (!use_global_beta_pow) { // Cpu update beta1_pow_out->mutable_data(platform::CPUPlace())[0] = beta1 * beta1_pow->data()[0]; beta2_pow_out->mutable_data(platform::CPUPlace())[0] = beta2 * beta2_pow->data()[0]; } } else { AdamWKernelMEM<<>>( beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data(), beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), lr->data(), grad->data(), param->data(), param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, param->numel()); if (!use_global_beta_pow) { // Update with gpu UpdateAdamWBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( beta1, beta2, beta1_pow->data(), beta2_pow->data(), beta1_pow_out->mutable_data(ctx.GetPlace()), beta2_pow_out->mutable_data(ctx.GetPlace())); } } } else if (grad_var->IsType()) { auto* grad = ctx.Input("Grad"); if (grad->rows().size() == 0) { VLOG(3) << "grad row size is 0!!"; return; } std::vector cpu_rows(grad->rows().begin(), grad->rows().end()); bool is_strict_sorted = true; for (size_t i = 1; i < cpu_rows.size(); ++i) { if (cpu_rows[i - 1] >= cpu_rows[i]) { is_strict_sorted = false; break; } } phi::SelectedRows tmp_grad_merge; const phi::SelectedRows* grad_merge_ptr; if (is_strict_sorted) { grad_merge_ptr = grad; } else { // merge duplicated rows if any. // The rows of grad_merge have been sorted inside MergeAdd functor scatter::MergeAdd merge_func; merge_func(ctx.template device_context(), *grad, &tmp_grad_merge, true); grad_merge_ptr = &tmp_grad_merge; } auto& grad_merge = *grad_merge_ptr; auto& grad_tensor = grad_merge.value(); const T* grad_data = grad_tensor.template data(); auto* grad_merge_rows = &grad_merge.rows(); paddle::framework::MixVector mixv_grad_merge_rows( grad_merge_rows); const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace()); auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); if (beta1_pow->place() == platform::CPUPlace() && beta2_pow->place() == platform::CPUPlace()) { int threads = 512; int ndim = param->numel(); int blocks = (ndim + threads - 1) / threads; SparseAdamWCUDAKernelREG< T, MPDType><<>>( beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data(), *beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), lr->data(), grad_data, param->data(), param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, rows, row_numel, grad_merge.rows().size(), lazy_mode, ndim); if (!use_global_beta_pow) { // Update with cpu beta1_pow_out->mutable_data(platform::CPUPlace())[0] = beta1 * beta1_pow->data()[0]; beta2_pow_out->mutable_data(platform::CPUPlace())[0] = beta2 * beta2_pow->data()[0]; } } else { SparseAdamWFunctor functor( beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data(), beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), lr->data(), grad_data, param->data(), param_out->mutable_data(ctx.GetPlace()), master_in_data, master_out_data, rows, row_numel, grad_merge.rows().size(), lazy_mode); // FIXME(minqiyang): remove BinarySearch in GPU later platform::ForRange for_range( static_cast( ctx.device_context()), param->numel()); for_range(functor); if (!use_global_beta_pow) { // update beta1 and beta2 UpdateAdamWBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( beta1, beta2, beta1_pow->data(), beta2_pow->data(), beta1_pow_out->mutable_data(ctx.GetPlace()), beta2_pow_out->mutable_data(ctx.GetPlace())); } } } else { PADDLE_THROW(platform::errors::InvalidArgument( "Variable type not supported by adamw_op")); } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(adamw, ops::AdamWOpCUDAKernel, ops::AdamWOpCUDAKernel, ops::AdamWOpCUDAKernel);