// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"

namespace phi {
namespace sr {

template <typename T, typename MT, typename Context, bool IsMultiPrecision>
void ComputeRowImpl(const Context& dev_ctx,
                    const DenseTensor& param,
                    const SelectedRows& grad,
                    const DenseTensor& lr,
                    const DenseTensor& mom1,
                    const DenseTensor& mom2,
                    const DenseTensor& beta1_pow,
                    const DenseTensor& beta2_pow,
                    const paddle::optional<DenseTensor>& master_param_opt,
                    const paddle::optional<DenseTensor>& skip_update_opt,
                    float weight_decay_f,
                    float beta1_f,
                    float beta2_f,
                    float epsilon_f,
                    bool multi_precision,
                    DenseTensor* param_out,
                    DenseTensor* mom1_out,
                    DenseTensor* mom2_out,
                    DenseTensor* beta1_pow_out,
                    DenseTensor* beta2_pow_out,
                    DenseTensor* master_param_out);

template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
                const DenseTensor& param,
                const SelectedRows& grad,
                const DenseTensor& learning_rate,
                const DenseTensor& moment1,
                const DenseTensor& moment2,
                const DenseTensor& beta1_pow,
                const DenseTensor& beta2_pow,
                const paddle::optional<DenseTensor>& master_param,
                const paddle::optional<DenseTensor>& skip_update,
                float weight_decay,
                float beta1,
                float beta2,
                float epsilon,
                bool multi_precision,
                DenseTensor* param_out,
                DenseTensor* moment1_out,
                DenseTensor* moment2_out,
                DenseTensor* beta1_pow_out,
                DenseTensor* beta2_pow_out,
                DenseTensor* master_param_outs) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  if (multi_precision) {
    ComputeRowImpl<T, MT, Context, true>(dev_ctx,
                                         param,
                                         grad,
                                         learning_rate,
                                         moment1,
                                         moment2,
                                         beta1_pow,
                                         beta2_pow,
                                         master_param,
                                         skip_update,
                                         weight_decay,
                                         beta1,
                                         beta2,
                                         epsilon,
                                         multi_precision,
                                         param_out,
                                         moment1_out,
                                         moment2_out,
                                         beta1_pow_out,
                                         beta2_pow_out,
                                         master_param_outs);
  } else {
    ComputeRowImpl<T, T, Context, false>(dev_ctx,
                                         param,
                                         grad,
                                         learning_rate,
                                         moment1,
                                         moment2,
                                         beta1_pow,
                                         beta2_pow,
                                         master_param,
                                         skip_update,
                                         weight_decay,
                                         beta1,
                                         beta2,
                                         epsilon,
                                         multi_precision,
                                         param_out,
                                         moment1_out,
                                         moment2_out,
                                         beta1_pow_out,
                                         beta2_pow_out,
                                         master_param_outs);
  }
}

template <typename T, typename MT, typename Context, bool IsMultiPrecision>
void ComputeRowImpl(const Context& dev_ctx,
                    const DenseTensor& param,
                    const SelectedRows& grad,
                    const DenseTensor& lr,
                    const DenseTensor& mom1,
                    const DenseTensor& mom2,
                    const DenseTensor& beta1_pow,
                    const DenseTensor& beta2_pow,
                    const paddle::optional<DenseTensor>& master_param_opt,
                    const paddle::optional<DenseTensor>& skip_update_opt,
                    float weight_decay_f,
                    float beta1_f,
                    float beta2_f,
                    float epsilon_f,
                    bool multi_precision,
                    DenseTensor* param_out,
                    DenseTensor* mom1_out,
                    DenseTensor* mom2_out,
                    DenseTensor* beta1_pow_out,
                    DenseTensor* beta2_pow_out,
                    DenseTensor* master_param_out) {
  if (!IsMultiPrecision) {
    constexpr auto kIsSameType = std::is_same<T, MT>::value;
    PADDLE_ENFORCE_EQ(
        kIsSameType,
        true,
        phi::errors::InvalidArgument(
            "When multi_precision=False, T and MT must be the same type."));
  }
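
  // For reference, the LAMB update (You et al., "Large Batch Optimization
  // for Deep Learning", ICLR 2020) that the functors below apply to every
  // element of param (rows absent from the sparse gradient contribute
  // g = 0):
  //
  //   m_t = beta1 * m_{t-1} + (1 - beta1) * g
  //   v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
  //   trust_ratio_div = (m_t / (1 - beta1^t))
  //                     / (sqrt(v_t / (1 - beta2^t)) + epsilon)
  //                     + weight_decay * p_{t-1}
  //   p_t = p_{t-1} - lr * (||p_{t-1}|| / ||trust_ratio_div||)
  //                      * trust_ratio_div
  //
  // The moment update fills trust_ratio_div first; the norm ratio is
  // applied afterwards by LambParamUpateFunctor.
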
  const auto* master_param =
      IsMultiPrecision ? master_param_opt.get_ptr() : nullptr;
  const auto* skip_update = skip_update_opt.get_ptr();
  const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
                                     ? skip_update->data<bool>()
                                     : nullptr;
  if (skip_update_flag &&
      skip_update->place().GetType() == phi::AllocationType::CPU &&
      (*skip_update_flag)) {
    return;
  }

  auto weight_decay = static_cast<MT>(weight_decay_f);
  auto beta1 = static_cast<MT>(beta1_f);
  auto beta2 = static_cast<MT>(beta2_f);
  auto epsilon = static_cast<MT>(epsilon_f);
  auto numel = param.numel();
  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
  DenseTensor trust_ratio_div;
  trust_ratio_div.Resize(param.dims());
  /*auto trust_ratio_div =
      ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);*/
  auto* trust_ratio_div_ptr = dev_ctx.template Alloc<MT>(&trust_ratio_div);

  const void* param_ptr = param.data();
  const void* master_param_ptr = master_param ? master_param->data() : nullptr;
  void* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
  void* master_param_out_ptr =
      master_param_out ? dev_ctx.template Alloc<MT>(master_param_out)
                       : nullptr;

  // Update moments
  bool should_update_beta_pow_later = false;
  const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
  MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
  VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
           << " , Beta2Pow place: " << beta2_pow.place();

  // Diff from here
  PADDLE_ENFORCE_EQ(
      IsMultiPrecision,
      false,
      phi::errors::Unimplemented("SelectedRows gradient is not supported when "
                                 "multi_precision=True."));
  constexpr bool kIsSameType = std::is_same<T, MT>::value;
  PADDLE_ENFORCE_EQ(
      kIsSameType,
      true,
      phi::errors::Unimplemented("SelectedRows gradient is not supported when "
                                 "multi_precision=True."));
  if (grad.rows().size() == 0) {
    VLOG(3) << "grad row size is 0!!";
    return;
  }

  std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
  bool is_strict_sorted = true;
  for (size_t i = 1; i < cpu_rows.size(); ++i) {
    if (cpu_rows[i - 1] >= cpu_rows[i]) {
      is_strict_sorted = false;
      break;
    }
  }
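
  // A SelectedRows gradient may arrive with duplicated and/or unsorted row
  // indices (e.g. the same embedding row touched several times in one
  // batch). The momentum functors below look up rows per element, which
  // requires strictly ascending, duplicate-free row indices, so a
  // non-sorted gradient is first normalized with MergeAdd (which sums
  // duplicates and sorts the rows).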
  phi::SelectedRows tmp_grad_merge;
  const phi::SelectedRows* grad_merge_ptr;
  if (is_strict_sorted) {
    grad_merge_ptr = &grad;
  } else {
    // merge duplicated rows if any.
    // The rows of grad_merge have been sorted inside MergeAdd functor
    phi::funcs::scatter::MergeAdd<Context, T> merge_func;
    merge_func(dev_ctx, grad, &tmp_grad_merge, true);
    grad_merge_ptr = &tmp_grad_merge;
  }

  auto& grad_merge = *grad_merge_ptr;
  auto& grad_tensor = grad_merge.value();
  const T* grad_data = grad_tensor.template data<T>();
  auto* grad_merge_rows = &grad_merge.rows();
  phi::MixVector<int64_t> mixv_grad_merge_rows(grad_merge_rows);
  const int64_t* rows = mixv_grad_merge_rows.Data(dev_ctx.GetPlace());
  auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
  if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU &&
      beta1_pow.place() == phi::CPUPlace() &&
      beta2_pow.place() == phi::CPUPlace()) {
    SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
        static_cast<T>(weight_decay),
        static_cast<T>(beta1),
        static_cast<T>(beta2),
        static_cast<T>(epsilon),
        *beta1_pow.template data<T>(),
        *beta2_pow.template data<T>(),
        mom1.template data<T>(),
        dev_ctx.template Alloc<T>(mom1_out),
        mom2.template data<T>(),
        dev_ctx.template Alloc<T>(mom2_out),
        grad_data,
        param.template data<T>(),
        trust_ratio_div.template data<T>(),
        rows,
        row_numel,
        grad_merge.rows().size(),
        skip_update_flag);
    for_range(moment_update_functor);
    T* beta1_pow_out_data = dev_ctx.template HostAlloc<T>(beta1_pow_out);
    beta1_pow_out_data[0] =
        static_cast<T>(beta1) * beta1_pow.template data<T>()[0];
    T* beta2_pow_out_data = dev_ctx.template HostAlloc<T>(beta2_pow_out);
    beta2_pow_out_data[0] =
        static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
  } else {
    beta1_pow_ptr = beta1_pow.template data<MT>();
    beta2_pow_ptr = beta2_pow.template data<MT>();
    beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
    beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
    should_update_beta_pow_later = true;
    SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
        static_cast<T>(weight_decay),
        static_cast<T>(beta1),
        static_cast<T>(beta2),
        static_cast<T>(epsilon),
        reinterpret_cast<const T*>(beta1_pow_ptr),
        reinterpret_cast<const T*>(beta2_pow_ptr),
        mom1.template data<T>(),
        dev_ctx.template Alloc<T>(mom1_out),
        mom2.template data<T>(),
        dev_ctx.template Alloc<T>(mom2_out),
        grad_data,
        param.template data<T>(),
        trust_ratio_div.template data<T>(),
        rows,
        row_numel,
        grad_merge.rows().size(),
        skip_update_flag);
    for_range(moment_update_functor);
  }
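
  // Both branches above compute the same moment update; they differ only in
  // where beta1_pow/beta2_pow live. The REG variant reads them by value on
  // the host (beta pows on CPU while the kernel runs on GPU) and advances
  // them immediately via HostAlloc; the MEN variant reads them through
  // device pointers and defers the advance to the parameter functor below.
  //
  // trust_ratio_div now holds the bias-corrected Adam-style step plus the
  // weight-decay term. The remaining LAMB step rescales it by
  // ||param|| / ||trust_ratio_div||: the squared L2 norms are computed
  // next, and the square roots are then taken inside LambParamUpateFunctor.
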
  // Same from here
  // Update parameter
  // The code in the following part is exactly the same as that in
  // paddle/phi/kernels/impl/lamb_kernel_impl.h Please modify it together
  DenseTensor p_norm_t;
  p_norm_t.Resize(phi::make_ddim({1}));
  auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);

  DenseTensor trust_ratio_div_norm_t;
  trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
  auto* trust_ratio_div_norm_ptr =
      dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);

  // TODO(zengjinle): remove the following Eigen operations when
  // *skip_update == true.
  memory_utils::Buffer buffer(dev_ctx.GetPlace());
  phi::funcs::SquaredL2Norm(
      dev_ctx,
      reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                   : param_ptr),
      p_norm_ptr,
      numel,
      &buffer);
  phi::funcs::SquaredL2Norm(
      dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);

  if (VLOG_IS_ON(1)) {
    const auto& name = "Param";
    auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
    auto tn =
        phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
    auto dtype = DataTypeToString(phi::CppTypeToDataType<MT>::Type());
    VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
            << " , tn = " << tn[0];
  }

#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)         \
  do {                                                                       \
    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
        param_update_functor(lr.template data<MT>(),                         \
                             static_cast<const T*>(param_ptr),               \
                             static_cast<const MT*>(master_param_ptr),       \
                             p_norm_ptr,                                     \
                             trust_ratio_div_ptr,                            \
                             trust_ratio_div_norm_ptr,                       \
                             static_cast<T*>(param_out_ptr),                 \
                             static_cast<MT*>(master_param_out_ptr),         \
                             skip_update_flag);                              \
    if (__should_update_beta_pow) {                                          \
      param_update_functor.SetBetaPows(beta1_pow_ptr,                        \
                                       beta2_pow_ptr,                        \
                                       beta1_pow_out_ptr,                    \
                                       beta2_pow_out_ptr,                    \
                                       beta1,                                \
                                       beta2);                               \
    }                                                                        \
    for_range(param_update_functor);                                         \
  } while (0)

  if (should_update_beta_pow_later) {
    CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
  } else {
    CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
  }

#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}

}  // namespace sr
}  // namespace phi
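
// Note: this header only defines the kernel body. A sketch of how the
// kernel is registered in the per-backend lamb_kernel.cc files (the kernel
// name "lamb_sr" and the dtype list here are assumptions; check the actual
// registration under paddle/phi/kernels/selected_rows/):
//
//   PD_REGISTER_KERNEL(
//       lamb_sr, CPU, ALL_LAYOUT, phi::sr::LambKernel, float, double) {}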