// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "paddle/phi/kernels/merged_momentum_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void MergedMomentumKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& params,
    const std::vector<const DenseTensor*>& grad,
    const std::vector<const DenseTensor*>& velocity,
    const std::vector<const DenseTensor*>& learning_rate,
    const paddle::optional<std::vector<const DenseTensor*>>& master_param,
    float mu_in,
    bool use_nesterov,
    const std::vector<std::string>& regularization_method,
    const std::vector<float>& regularization_coeff,
    bool multi_precision,
    float rescale_grad,
    std::vector<DenseTensor*> params_out,
    std::vector<DenseTensor*> velocity_out,
    std::vector<DenseTensor*> master_param_out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto lr = learning_rate[0];
  T mu = static_cast<T>(mu_in);
  int op_num = params.size();
  PADDLE_ENFORCE_EQ(op_num,
                    params_out.size(),
                    errors::InvalidArgument(
                        "The size of Output(ParamOut) must be equal to "
                        "Input(Param), but got the size of Output(ParamOut) "
                        "is %d, the size of Input(Param) is %d.",
                        params_out.size(),
                        op_num));
  PADDLE_ENFORCE_EQ(op_num,
                    velocity.size(),
                    errors::InvalidArgument(
                        "The size of Input(Velocity) must be equal to "
                        "Input(Param), but got the size of Input(Velocity) "
                        "is %d, the size of Input(Param) is %d.",
                        velocity.size(),
                        op_num));
  PADDLE_ENFORCE_EQ(
      op_num,
      velocity_out.size(),
      errors::InvalidArgument(
          "The size of Output(VelocityOut) must be equal to "
          "Input(Param), but got the size of Output(VelocityOut) "
          "is %d, the size of Input(Param) is %d.",
          velocity_out.size(),
          op_num));
  PADDLE_ENFORCE_EQ(
      op_num,
      grad.size(),
      errors::InvalidArgument(
          "The size of Input(Grad) must be equal to Input(Param), but got "
          "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
          grad.size(),
          op_num));

  // Gather raw device pointers, per-tensor element counts, and L2 weight
  // decay coefficients so all parameter/velocity pairs can be updated in a
  // single fused XPU call.
  std::vector<XPUType*> param_list(op_num);
  std::vector<XPUType*> velocity_list(op_num);
  std::vector<XPUType*> grad_list(op_num);
  std::vector<XPUType*> velocity_out_list(op_num);
  std::vector<XPUType*> param_out_list(op_num);
  std::vector<int> sizes(op_num);
  std::vector<float> l2_weight_decay(op_num);
  if (op_num > 0) {
    for (int j = 0; j < op_num; j++) {
      param_list[j] =
          reinterpret_cast<XPUType*>(const_cast<T*>(params[j]->data<T>()));
      velocity_list[j] =
          reinterpret_cast<XPUType*>(const_cast<T*>(velocity[j]->data<T>()));
      grad_list[j] =
          reinterpret_cast<XPUType*>(const_cast<T*>(grad[j]->data<T>()));
      param_out_list[j] = reinterpret_cast<XPUType*>(params_out[j]->data<T>());
      velocity_out_list[j] =
          reinterpret_cast<XPUType*>(velocity_out[j]->data<T>());
      sizes[j] = static_cast<int>(params[j]->numel());
      if (regularization_method[j] != "l2_decay") {
        l2_weight_decay[j] = 0.0f;
      } else {
        l2_weight_decay[j] = static_cast<float>(regularization_coeff[j]);
      }
      // The update is performed in place, so each output must alias its input.
      PADDLE_ENFORCE_EQ(params[j],
                        params_out[j],
                        errors::InvalidArgument(
                            "Input(Param) and Output(ParamOut) must be "
                            "the same Tensors."));
      PADDLE_ENFORCE_EQ(velocity[j],
                        velocity_out[j],
                        errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
    }
  } else {
    return;
  }

  int r = xpu::merged_momentum(dev_ctx.x_context(),
                               param_list,
                               velocity_list,
                               grad_list,
                               param_out_list,
                               velocity_out_list,
                               l2_weight_decay,
                               sizes,
                               lr->data<float>(),
                               mu,
                               use_nesterov);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_momentum");
}

}  // namespace phi

PD_REGISTER_KERNEL(merged_momentum,
                   XPU,
                   ALL_LAYOUT,
                   phi::MergedMomentumKernel,
                   float,
                   phi::dtype::float16) {}