// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/lars_momentum_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { template void LarsMomentumKernel( const Context& dev_ctx, const std::vector& param, const std::vector& velocity, const std::vector& learning_rate, const std::vector& grad, const paddle::optional>& master_param, const std::vector& weight_decay_arr, float mu, float lars_coeff, float epsilon, bool multi_precision, float rescale_grad, std::vector param_out, std::vector velocity_out, std::vector master_param_out) { int op_num = static_cast(param.size()); T mu_ = static_cast(mu); for (int i = 0; i < op_num; ++i) { auto* lr = learning_rate[i]->data(); T lars_weight_decay = weight_decay_arr[i]; dev_ctx.template Alloc(param_out[i]); dev_ctx.template Alloc(velocity_out[i]); auto p_out = phi::EigenVector::Flatten(*(param_out[i])); auto v_out = phi::EigenVector::Flatten(*(velocity_out[i])); auto p = phi::EigenVector::Flatten(*(param[i])); auto v = phi::EigenVector::Flatten(*(velocity[i])); Eigen::TensorMap> g = phi::EigenVector::Flatten(*(grad[i])); auto rescale_g = static_cast(rescale_grad) * g; phi::DenseTensor p_norm_t, g_norm_t; p_norm_t.Resize({1}); g_norm_t.Resize({1}); dev_ctx.template Alloc(&p_norm_t); dev_ctx.template Alloc(&g_norm_t); auto ep_norm = phi::EigenScalar::From(p_norm_t); auto eg_norm = phi::EigenScalar::From(g_norm_t); ep_norm = p.square().sum().sqrt(); eg_norm = rescale_g.square().sum().sqrt(); T local_lr = lr[0]; if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { local_lr = lr[0] * lars_coeff * ep_norm(0) / (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); } v_out = v * mu_ + local_lr * (rescale_g + lars_weight_decay * p); p_out = p - v_out; } } } // namespace phi PD_REGISTER_KERNEL( lars_momentum, CPU, ALL_LAYOUT, phi::LarsMomentumKernel, float, double) {}