// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/lamb_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
                const DenseTensor& param,
                const DenseTensor& grad,
                const DenseTensor& learning_rate,
                const DenseTensor& moment1,
                const DenseTensor& moment2,
                const DenseTensor& beta1_pow,
                const DenseTensor& beta2_pow,
                const paddle::optional<DenseTensor>& master_param,
                const paddle::optional<DenseTensor>& skip_update,
                float weight_decay,
                float beta1,
                float beta2,
                float epsilon,
                bool multi_precision,
                DenseTensor* param_outs,
                DenseTensor* moment1_out,
                DenseTensor* moment2_out,
                DenseTensor* beta1_pow_out,
                DenseTensor* beta2_pow_out,
                DenseTensor* master_param_outs) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;

  if (!multi_precision) {
    constexpr auto kIsSameType = std::is_same<T, MT>::value;
    PADDLE_ENFORCE_EQ(
        kIsSameType,
        true,
        phi::errors::InvalidArgument(
            "When multi_precision=False, T and MT must be the same type."));
  }

  // skip_update may live on CPU or device; fetch its value to the host first.
  bool cpu_skip_update = false;
  if (skip_update && skip_update->IsInitialized()) {
    if (paddle::platform::is_cpu_place(skip_update->place())) {
      cpu_skip_update = *(skip_update->data<bool>());
    } else {
      const bool* skip_update_flag = skip_update->data<bool>();
      memory_utils::Copy(phi::CPUPlace(),
                         static_cast<void*>(&cpu_skip_update),
                         dev_ctx.GetPlace(),
                         static_cast<const void*>(skip_update_flag),
                         sizeof(bool));
    }
  }
  if (cpu_skip_update) {
    return;
  }

  // tensor --> data_ptr
  // inputs
  const XPUType* param_ptr = reinterpret_cast<const XPUType*>(param.data<T>());
  const XPUType* grad_ptr = reinterpret_cast<const XPUType*>(grad.data<T>());
  const MT* learning_rate_ptr = learning_rate.data<MT>();
  const MT* moment1_ptr = moment1.data<MT>();
  const MT* moment2_ptr = moment2.data<MT>();
  const MT* beta1_pow_ptr = beta1_pow.data<MT>();
  const MT* beta2_pow_ptr = beta2_pow.data<MT>();
  const MT* master_param_ptr = nullptr;
  if (multi_precision) {
    master_param_ptr = master_param.get_ptr()->data<MT>();
  }

  // outputs
  XPUType* param_outs_ptr =
      reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_outs));
  MT* moment1_out_ptr = dev_ctx.template Alloc<MT>(moment1_out);
  MT* moment2_out_ptr = dev_ctx.template Alloc<MT>(moment2_out);
  MT* master_param_outs_ptr = nullptr;
  if (multi_precision) {
    if (master_param_outs->numel() != master_param.get_ptr()->numel()) {
      master_param_outs->Resize(master_param.get_ptr()->dims());
    }
    master_param_outs_ptr = dev_ctx.template Alloc<MT>(master_param_outs);
  }

  MT* beta1_pow_out_ptr = nullptr;
  MT* beta2_pow_out_ptr = nullptr;
  MT* beta1_pow_xpu_ptr = nullptr;
  MT* beta2_pow_xpu_ptr = nullptr;
  xpu::Context* xpu_ctx = dev_ctx.x_context();
  xpu::ctx_guard RAII_GUARD(xpu_ctx);

  // beta1_pow/beta2_pow may live on CPU; if so, stage device copies for the
  // kernel and use scratch buffers for the updated values.
  if (beta1_pow.place().GetType() == phi::AllocationType::CPU) {
    int r = xpu_malloc(reinterpret_cast<void**>(&beta1_pow_xpu_ptr),
                       (beta1_pow.numel()) * sizeof(MT));
    PADDLE_ENFORCE_XPU_SUCCESS(r);
    memory_utils::Copy(dev_ctx.GetPlace(),
                       beta1_pow_xpu_ptr,
                       beta1_pow.place(),
                       beta1_pow.data<MT>(),
                       sizeof(MT) * beta1_pow.numel());
    beta1_pow_ptr = beta1_pow_xpu_ptr;

    beta1_pow_out_ptr = RAII_GUARD.alloc_l3_or_gm<MT>(beta1_pow_out->numel());
    PADDLE_ENFORCE_XDNN_NOT_NULL(beta1_pow_out_ptr);
  } else {
    beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
  }

  if (beta2_pow.place().GetType() == phi::AllocationType::CPU) {
    int r = xpu_malloc(reinterpret_cast<void**>(&beta2_pow_xpu_ptr),
                       (beta2_pow.numel()) * sizeof(MT));
    PADDLE_ENFORCE_XPU_SUCCESS(r);
    memory_utils::Copy(dev_ctx.GetPlace(),
                       beta2_pow_xpu_ptr,
                       beta2_pow.place(),
                       beta2_pow.data<MT>(),
                       sizeof(MT) * beta2_pow.numel());
    beta2_pow_ptr = beta2_pow_xpu_ptr;

    beta2_pow_out_ptr = RAII_GUARD.alloc_l3_or_gm<MT>(beta2_pow_out->numel());
    PADDLE_ENFORCE_XDNN_NOT_NULL(beta2_pow_out_ptr);
  } else {
    beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
  }

  // xpu::lamb computes in MT (float); for float16 inputs, cast param/grad
  // into float scratch buffers first.
  const MT* param_calc_ptr = nullptr;
  const MT* grad_calc_ptr = nullptr;
  MT* param_outs_calc_ptr = nullptr;
  if (std::is_same<T, phi::dtype::float16>::value) {
    MT* param_float = RAII_GUARD.alloc_l3_or_gm<MT>(param.numel());
    PADDLE_ENFORCE_XDNN_NOT_NULL(param_float);
    MT* grad_float = RAII_GUARD.alloc_l3_or_gm<MT>(grad.numel());
    PADDLE_ENFORCE_XDNN_NOT_NULL(grad_float);
    MT* param_outs_float = RAII_GUARD.alloc_l3_or_gm<MT>(param_outs->numel());
    PADDLE_ENFORCE_XDNN_NOT_NULL(param_outs_float);

    int r = xpu::cast<XPUType, MT>(
        xpu_ctx, param_ptr, param_float, param.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    r = xpu::cast<XPUType, MT>(xpu_ctx, grad_ptr, grad_float, grad.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");

    param_calc_ptr = param_float;
    grad_calc_ptr = grad_float;
    param_outs_calc_ptr = param_outs_float;
  } else {
    param_calc_ptr = reinterpret_cast<const MT*>(param_ptr);
    grad_calc_ptr = reinterpret_cast<const MT*>(grad_ptr);
    param_outs_calc_ptr = reinterpret_cast<MT*>(param_outs_ptr);
  }

  int r = xpu::lamb<MT>(
      xpu_ctx,
      grad_calc_ptr,
      moment1_ptr,
      moment2_ptr,
      (multi_precision ? master_param_ptr : param_calc_ptr),
      beta1_pow_ptr,
      beta2_pow_ptr,
      moment1_out_ptr,
      moment2_out_ptr,
      (multi_precision ? master_param_outs_ptr : param_outs_calc_ptr),
      beta1_pow_out_ptr,
      beta2_pow_out_ptr,
      beta1,
      beta2,
      epsilon,
      weight_decay,
      learning_rate_ptr,
      param.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "lamb");

  // Without master weights, cast the float result back to float16 output.
  if (std::is_same<T, phi::dtype::float16>::value && multi_precision == false) {
    int r = xpu::cast<MT, XPUType>(
        xpu_ctx, param_outs_calc_ptr, param_outs_ptr, param_outs->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
  }

  if (beta1_pow.place().GetType() == phi::AllocationType::CPU) {
    // copy beta1_pow_out from xpu to cpu
    memory_utils::Copy(beta1_pow.place(),
                       dev_ctx.template HostAlloc<MT>(beta1_pow_out),
                       dev_ctx.GetPlace(),
                       beta1_pow_out_ptr,
                       sizeof(MT) * beta1_pow_out->numel());
    if (beta1_pow_xpu_ptr) {
      xpu_free(beta1_pow_xpu_ptr);
    }
  }
  if (beta2_pow.place().GetType() == phi::AllocationType::CPU) {
    // copy beta2_pow_out from xpu to cpu
    memory_utils::Copy(beta2_pow.place(),
                       dev_ctx.template HostAlloc<MT>(beta2_pow_out),
                       dev_ctx.GetPlace(),
                       beta2_pow_out_ptr,
                       sizeof(MT) * beta2_pow_out->numel());
    if (beta2_pow_xpu_ptr) {
      xpu_free(beta2_pow_xpu_ptr);
    }
  }
}
}  // namespace phi

PD_REGISTER_KERNEL(
    lamb, XPU, ALL_LAYOUT, phi::LambKernel, float, phi::dtype::float16) {
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
}