// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/core/macros.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/impl/momentum_kernel_impl.h" #include "paddle/phi/kernels/merged_momentum_kernel.h" namespace phi { template using MultiPrecisionType = typename phi::dtype::MPTypeTrait::Type; template struct MergedMomentumMasterParams { MT *PADDLE_RESTRICT master_params[kParamNum]; HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; } HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; } }; template struct MergedMomentumMasterParams { HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; } HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {} }; template struct MergedMomentumKernelParam : public MergedMomentumMasterParams { static constexpr auto N = kParamNum; size_t sizes[N]; T *PADDLE_RESTRICT params[N]; const T *PADDLE_RESTRICT grads[N]; MT *PADDLE_RESTRICT velocitys[N]; const MultiPrecisionType *PADDLE_RESTRICT lr; MT mu; MT rescale_grad; uint32_t param_num; HOSTDEVICE void operator()(size_t i) const { const MT lr_val = static_cast(*lr); for (uint32_t idx = 0; idx < param_num; ++idx) { auto size = sizes[idx]; if (i >= size) continue; auto param_p = params[idx]; auto grad_p = grads[idx]; auto velocity_p = velocitys[idx]; auto master_param_p = this->MasterParam(idx); const MT param = master_param_p ? master_param_p[i] : static_cast(param_p[i]); const MT grad = static_cast(grad_p[i]) * rescale_grad; const MT velocity = velocity_p[i]; const MT velocity_out = velocity * mu + grad; const MT param_out = param - lr_val * velocity_out; velocity_p[i] = velocity_out; param_p[i] = static_cast(param_out); if (master_param_p) { master_param_p[i] = param_out; } } } }; template void MergedMomentumInnerCompute( const Context &ctx, const std::vector ¶ms, const std::vector &grads, const std::vector &velocitys, const std::vector &lrs, const paddle::optional> &master_params_opt, float mu, bool use_nesterov, const std::vector ®ularization_methods, const std::vector ®ularization_coeffs, float rescale_grad, const bool multi_precision, std::vector params_out, std::vector velocitys_out, std::vector master_params_out) { size_t n = params.size(); PADDLE_ENFORCE_EQ(n, params_out.size(), phi::errors::InvalidArgument( "The size of Output(ParamOut) must be equal to " "Input(Param), but got the size of Output(ParamOut) " "is %d, the size of Input(Param) is %d.", params_out.size(), n)); for (size_t i = 0; i < n; ++i) { PADDLE_ENFORCE_EQ( params[i], params_out[i], phi::errors::InvalidArgument("Input(Param) and Output(ParamOut) " "must be the same Tensors.")); } PADDLE_ENFORCE_EQ( n, grads.size(), phi::errors::InvalidArgument( "The size of Input(Grad) must be equal to Input(Param), but got " "the size of Input(Grad) is %d, the size of Input(Param) is %d.", grads.size(), n)); PADDLE_ENFORCE_EQ(n, velocitys.size(), phi::errors::InvalidArgument( "The size of Input(Velocity) must be equal to " "Input(Param), but got the size of Input(Velocity) " "is %d, the size of Input(Param) is %d.", velocitys.size(), n)); PADDLE_ENFORCE_EQ( n, velocitys_out.size(), phi::errors::InvalidArgument( "The size of Output(VelocityOut) must be " "equal to Input(Param), but got the size of Output(VelocityOut) is " "%d, the size of Input(Param) is %d.", velocitys_out.size(), n)); for (size_t i = 0; i < n; ++i) { PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i], phi::errors::InvalidArgument( "Input(Velocity) and Output(VelocityOut) must be " "the same Tensors.")); } if (multi_precision) { auto master_params = master_params_opt.get(); PADDLE_ENFORCE_EQ( n, master_params.size(), phi::errors::InvalidArgument( "The size of Input(MasterParam) must be " "equal to Input(Param), but got the size of Input(MasterParam) " "is %d, the size of Input(Param) is %d.", master_params.size(), n)); PADDLE_ENFORCE_EQ( n, master_params_out.size(), phi::errors::InvalidArgument( "The size of Output(MasterParamOut) must be equal to " "Input(MasterParam), but got the size of Output(MasterParamOut) " "is %d, the size of Input(Param) is %d.", master_params_out.size(), n)); for (size_t i = 0; i < n; ++i) { PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i], phi::errors::InvalidArgument( "Input(MasterParam) and Output(MasterParamOut) " "must be the same Tensors.")); PADDLE_ENFORCE_NOT_NULL(master_params[i], phi::errors::InvalidArgument( "Input(MasterParam) must be provided when " "multi_precision=True.")); } } else { master_params_out.clear(); } if (lrs.size() != 1) { PADDLE_ENFORCE_EQ( n, lrs.size(), phi::errors::InvalidArgument( "If the size of Input(LearningRate) is not 1, the size of " "Input(LearningRate) must be " "equal to Input(Param), but got the size of Input(LearningRate) " "is %d, the size of Input(Param) is %d.", lrs.size(), n)); } if (regularization_methods.size() != 0) { PADDLE_ENFORCE_EQ( n, regularization_methods.size(), phi::errors::InvalidArgument( "The size of Attr(regularization_method) must be equal " "to Input(Param), but got the size of " "Attr(regularization_method) is %d, the size of Input(Param) is " "%d.", regularization_methods.size(), n)); PADDLE_ENFORCE_EQ( n, regularization_coeffs.size(), phi::errors::InvalidArgument( "The size of Attr(regularization_coeff) must be equal " "to Input(Param), but got the size of Attr(regularization_coeff) " "is %d, the size of Input(Param) is %d.", regularization_coeffs.size(), n)); } VLOG(5) << "use_nesterov: " << use_nesterov << ", regularization_methods.size(): " << regularization_methods.size() << ", regularization_coeffs.size(): " << regularization_coeffs.size(); if (lrs.size() == 1 && use_nesterov == false && regularization_methods.size() == 0) { #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ MergedMomentumKernelParam kernel_params; \ constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \ kernel_params.mu = static_cast(mu); \ kernel_params.rescale_grad = static_cast(rescale_grad); \ kernel_params.lr = lrs[0]->data(); \ for (size_t i = 0; i < kernel_num; ++i) { \ size_t start = i * kMaxMergedNum; \ size_t end = std::min((i + 1) * kMaxMergedNum, n); \ kernel_params.param_num = static_cast(end - start); \ size_t max_size = 0; \ for (size_t j = 0; j < kernel_params.param_num; ++j) { \ auto size = static_cast(params_out[j + start]->numel()); \ max_size = std::max(max_size, size); \ kernel_params.sizes[j] = size; \ kernel_params.params[j] = params_out[j + start]->data(); \ kernel_params.grads[j] = grads[j + start]->data(); \ kernel_params.velocitys[j] = velocitys_out[j + start]->data(); \ kernel_params.SetMasterParam( \ j, \ kMultiPrecision ? master_params_out[j + start]->data() \ : nullptr); \ } \ phi::funcs::ForRange for_range(ctx, max_size); \ for_range(kernel_params); \ VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ << kernel_params.param_num; \ } if (multi_precision) { PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); } else { PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false); } #undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL } else { for (size_t idx = 0; idx < n; idx++) { phi::RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" ? phi::RegularizationType::kL2DECAY : phi::RegularizationType::kNONE; MT regularization_coeff = static_cast(0.0); if (regularization_coeffs.size() != 0) { regularization_coeff = static_cast(regularization_coeffs[idx]); } auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0]; const MT *master_in_data = multi_precision ? master_params_opt.get()[idx]->data() : nullptr; MT *master_out_data = multi_precision ? master_params_out[idx]->data() : nullptr; if (paddle::platform::is_cpu_place(ctx.GetPlace())) { phi::CPUDenseMomentumFunctor functor; functor(params[idx], grads[idx], velocitys[idx], lr_temp, static_cast(mu), use_nesterov, regularization_flag, regularization_coeff, params_out[idx], velocitys_out[idx]); VLOG(10) << "Launch MergedMomentum cpu kernel."; } else if (paddle::platform::is_gpu_place(ctx.GetPlace())) { phi::funcs::ForRange for_range( static_cast(ctx), params[idx]->numel()); #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ phi::DenseMomentumFunctor functor( \ params[idx]->data(), \ grads[idx]->data(), \ velocitys[idx]->data(), \ lr_temp->data(), \ master_in_data, \ static_cast(mu), \ static_cast(rescale_grad), \ params[idx]->numel(), \ regularization_coeff, \ params_out[idx]->data(), \ velocitys_out[idx]->data(), \ master_out_data); \ for_range(functor); if (use_nesterov) { if (regularization_flag == phi::RegularizationType::kL2DECAY) { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( phi::UseNesterov, phi::RegularizationType::kL2DECAY); VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY."; } else { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( phi::UseNesterov, phi::RegularizationType::kNONE); VLOG(10) << "Launch MergedMomentum gpu kernel use_nesterov kNONE."; } } else { if (regularization_flag == phi::RegularizationType::kL2DECAY) { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( phi::NoNesterov, phi::RegularizationType::kL2DECAY); VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY."; } else { PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL( phi::NoNesterov, phi::RegularizationType::kNONE); VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE."; } } } } VLOG(10) << "Launch MergedMomentum kernel with multi_lr and regularization."; } } template void MergedMomentumKernel( const Context &dev_ctx, const std::vector ¶m, const std::vector &grad, const std::vector &velocity, const std::vector &learning_rate, const paddle::optional> &master_param, float mu, bool use_nesterov, const std::vector ®ularization_method, const std::vector ®ularization_coeff, bool multi_precision, float rescale_grad, std::vector param_out, std::vector velocity_out, std::vector master_param_out) { using MPType = typename phi::dtype::MPTypeTrait::Type; if (multi_precision) { MergedMomentumInnerCompute( dev_ctx, param, grad, velocity, learning_rate, master_param, mu, use_nesterov, regularization_method, regularization_coeff, rescale_grad, multi_precision, param_out, velocity_out, master_param_out); } else { MergedMomentumInnerCompute(dev_ctx, param, grad, velocity, learning_rate, master_param, mu, use_nesterov, regularization_method, regularization_coeff, rescale_grad, multi_precision, param_out, velocity_out, master_param_out); } } } // namespace phi