// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fused_adam_kernel.h"

#include <vector>

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/adam_kernel.h"
#include "paddle/phi/kernels/adamw_kernel.h"

namespace phi {

// Returns the idx-th tensor of an optional tensor-pointer list as an optional
// DenseTensor, or paddle::none when the list itself is absent.
static paddle::optional<DenseTensor> TensorPtrToOptionalTensor(
    const paddle::optional<std::vector<const DenseTensor*>>& t, size_t idx) {
  return t ? paddle::optional<DenseTensor>(*(t.get()[idx])) : paddle::none;
}

template <typename T, typename Context>
void FusedAdamKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& params,
    const std::vector<const DenseTensor*>& grads,
    const DenseTensor& learning_rate,
    const std::vector<const DenseTensor*>& moments1,
    const std::vector<const DenseTensor*>& moments2,
    const std::vector<const DenseTensor*>& beta1_pows,
    const std::vector<const DenseTensor*>& beta2_pows,
    const paddle::optional<std::vector<const DenseTensor*>>& master_params,
    const paddle::optional<DenseTensor>& skip_update,
    const Scalar& beta1,
    const Scalar& beta2,
    const Scalar& epsilon,
    int chunk_size,
    float weight_decay,
    bool use_adamw,
    bool multi_precision,
    bool use_global_beta_pow,
    std::vector<DenseTensor*> params_out,
    std::vector<DenseTensor*> moments1_out,
    std::vector<DenseTensor*> moments2_out,
    std::vector<DenseTensor*> beta1_pows_out,
    std::vector<DenseTensor*> beta2_pows_out,
    std::vector<DenseTensor*> master_params_out) {
  // Every per-parameter input list must have the same length as params.
  size_t params_num = params.size();
  PADDLE_ENFORCE_EQ(
      params_num,
      grads.size(),
      errors::InvalidArgument("The size of Input(grads) must be equal to "
                              "Input(params), but got the size of "
                              "Input(grads) is %d, the size of Input(params) "
                              "is %d.",
                              grads.size(),
                              params_num));
  PADDLE_ENFORCE_EQ(
      params_num,
      moments1.size(),
      errors::InvalidArgument("The size of Input(moments1) must be equal to "
                              "Input(params), but got the size of "
                              "Input(moments1) is %d, the size of "
                              "Input(params) is %d.",
                              moments1.size(),
                              params_num));
  PADDLE_ENFORCE_EQ(
      params_num,
      moments2.size(),
      errors::InvalidArgument("The size of Input(moments2) must be equal to "
                              "Input(params), but got the size of "
                              "Input(moments2) is %d, the size of "
                              "Input(params) is %d.",
                              moments2.size(),
                              params_num));
  PADDLE_ENFORCE_EQ(
      params_num,
      beta1_pows.size(),
      errors::InvalidArgument("The size of Input(beta1_pows) must be equal to "
                              "Input(params), but got the size of "
                              "Input(beta1_pows) is %d, the size of "
                              "Input(params) is %d.",
                              beta1_pows.size(),
                              params_num));
  PADDLE_ENFORCE_EQ(
      params_num,
      beta2_pows.size(),
      errors::InvalidArgument("The size of Input(beta2_pows) must be equal to "
                              "Input(params), but got the size of "
                              "Input(beta2_pows) is %d, the size of "
                              "Input(params) is %d.",
                              beta2_pows.size(),
                              params_num));

  // Update each parameter in turn by delegating to the single-tensor
  // Adam / AdamW kernels.
  for (size_t idx = 0; idx < params_num; idx++) {
    auto master_params_tmp = TensorPtrToOptionalTensor(master_params, idx);
    if (!use_adamw) {
      AdamDenseKernel<T, Context>(
          dev_ctx,
          *params[idx],
          *grads[idx],
          learning_rate,
          *moments1[idx],
          *moments2[idx],
          *beta1_pows[idx],
          *beta2_pows[idx],
          master_params_tmp,
          skip_update,
          beta1,
          beta2,
          epsilon,
          /*lazy_mode=*/false,
          /*min_row_size_to_use_multithread=*/1000,
          multi_precision,
          use_global_beta_pow,
          params_out[idx],
          moments1_out[idx],
          moments2_out[idx],
          beta1_pows_out[idx],
          beta2_pows_out[idx],
          master_params_out.empty() ? nullptr : master_params_out[idx]);
    } else {
      AdamwDenseKernel<T, Context>(
          dev_ctx,
          *params[idx],
          *grads[idx],
          learning_rate,
          *moments1[idx],
          *moments2[idx],
          *beta1_pows[idx],
          *beta2_pows[idx],
          master_params_tmp,
          skip_update,
          beta1,
          beta2,
          epsilon,
          /*lr_ratio=*/1.0,
          /*coeff=*/weight_decay,
          /*with_decay=*/use_adamw,
          /*lazy_mode=*/false,
          /*min_row_size_to_use_multithread=*/1000,
          multi_precision,
          use_global_beta_pow,
          params_out[idx],
          moments1_out[idx],
          moments2_out[idx],
          beta1_pows_out[idx],
          beta2_pows_out[idx],
          master_params_out.empty() ? nullptr : master_params_out[idx]);
    }
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(
    fused_adam, CPU, ALL_LAYOUT, phi::FusedAdamKernel, float, double) {}