/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Dense Adam update with the beta1/beta2 power terms passed by value.
//
// T  : element type of grad / param / param_out.
// MT : type used for the arithmetic, the moments, the learning rate and the
//      optional master parameter.
//
// For every element id in [0, ndim):
//   mom1 = beta1 * mom1 + (1 - beta1) * g
//   mom2 = beta2 * mom2 + (1 - beta2) * g * g
//   p   -= lr_t * mom1 / (sqrt(mom2) + epsilon * sqrt(1 - beta2_pow))
// where lr_t = (*lr_) * sqrt(1 - beta2_pow) / (1 - beta1_pow).
//
// master_param may be null; the parameter value is then read from `param`.
// master_param_out may be null; only `param_out` is written in that case.
// The loop is grid-stride over a 1-D launch, so any grid size covers ndim.
template <typename T, typename MT>
__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
                              MT beta2_pow_, const MT* moment1,
                              MT* moment1_out, const MT* moment2,
                              MT* moment2_out, const MT* lr_, const T* grad,
                              const T* param, T* param_out,
                              const MT* master_param, MT* master_param_out,
                              int ndim) {
  MT lr = *lr_;
  MT beta1_pow = beta1_pow_;
  MT beta2_pow = beta2_pow_;

  // Fold the bias correction into the learning rate once per thread.
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  int id = blockIdx.x * blockDim.x + threadIdx.x;

  for (; id < ndim; id += gridDim.x * blockDim.x) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = moment1[id];
    MT mom2 = moment2[id];
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// Dense Adam update; identical arithmetic to AdamKernelREG, except that the
// beta1/beta2 power terms are read through the pointers beta1_pow_ and
// beta2_pow_ inside the kernel instead of being passed by value.
template <typename T, typename MT>
__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
                              const MT* beta1_pow_, const MT* beta2_pow_,
                              const MT* moment1, MT* moment1_out,
                              const MT* moment2, MT* moment2_out,
                              const MT* lr_, const T* grad, const T* param,
                              T* param_out, const MT* master_param,
                              MT* master_param_out, int ndim) {
  MT lr = *lr_;
  MT beta1_pow = *beta1_pow_;
  MT beta2_pow = *beta2_pow_;

  // Fold the bias correction into the learning rate once per thread.
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  int id = blockIdx.x * blockDim.x + threadIdx.x;

  for (; id < ndim; id += gridDim.x * blockDim.x) {
    MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
    MT g = static_cast<MT>(grad[id]);
    MT mom1 = static_cast<MT>(moment1[id]);
    MT mom2 = static_cast<MT>(moment2[id]);
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    moment1_out[id] = mom1;
    moment2_out[id] = mom2;
    param_out[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// Advances the two scalar power terms by one step:
//   *beta1_pow_out = beta1 * beta1_pow_[0]
//   *beta2_pow_out = beta2 * beta2_pow_[0]
//
// Only thread 0 of block 0 performs the read-multiply-write. The callers in
// this file launch 32 threads; without the guard every thread would read the
// input and write the output, which is a race if an output pointer aliases
// its input (a lane could then read an already-updated value).
template <typename T>
__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
                              const T* beta2_pow_, T* beta1_pow_out,
                              T* beta2_pow_out) {
  if (blockIdx.x != 0 || threadIdx.x != 0) {
    return;
  }
  *beta1_pow_out = beta1 * beta1_pow_[0];
  *beta2_pow_out = beta2 * beta2_pow_[0];
}

// Adam update for a row-sparse gradient, with the beta power terms passed by
// value.
//
// rows_ holds row_count row indices and is looked up with
// math::BinarySearch; grad_ stores row_numel values per listed row. For each
// flat element id in [0, ndim) the owning row is id / row_numel:
//   * row found     -> gradient is grad_[row_idx * row_numel + id % row_numel]
//   * row not found -> element is skipped when lazy_mode is true, otherwise
//                      it is updated with a zero gradient.
//
// master_param / master_param_out may be null, as in the dense kernels.
template <typename T, typename MT>
__global__ void SparseAdamCUDAKernelREG(
    MT beta1, MT beta2, MT epsilon, const MT beta1_pow, const MT beta2_pow,
    const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_,
    const MT* lr_, const T* grad_, const T* param_, T* param_out_,
    const MT* master_param, MT* master_param_out, const int64_t* rows_,
    int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  MT lr = *lr_;

  // Fold the bias correction into the learning rate once per thread.
  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
        (static_cast<MT>(1.0) - beta1_pow);

  for (; id < ndim; id += blockDim.x * gridDim.x) {
    auto row_idx =
        math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
    if (lazy_mode && row_idx < 0) {
      // Skip only this element. A `return` here would also drop every later
      // element this thread owns whenever the grid is smaller than ndim.
      continue;
    }

    MT mom1 = mom1_[id];
    MT mom2 = mom2_[id];
    MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
    MT g = row_idx >= 0
               ? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
               : static_cast<MT>(0);
    mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
    mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));

    // Write back to global memory
    mom1_out_[id] = mom1;
    mom2_out_[id] = mom2;
    param_out_[id] = static_cast<T>(p);
    if (master_param_out) {
      master_param_out[id] = p;
    }
  }
}

// CUDA kernel class of the `adam` operator.
//
// Compute() reads the inputs Param, Grad, Moment1, Moment2, LearningRate,
// Beta1Pow, Beta2Pow (plus the optional Beta1Tensor / Beta2Tensor and, with
// `multi_precision`, MasterParam) and writes ParamOut, Moment1Out,
// Moment2Out, Beta1PowOut, Beta2PowOut (and MasterParamOut).
//
// Grad may be a LoDTensor (dense path) or a SelectedRows (sparse path); any
// other variable type raises InvalidArgument.
template <typename T>
class AdamOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;
    // Arithmetic / moment type associated with T by MPTypeTrait.
    using MPDType = typename details::MPTypeTrait<T>::Type;

    // NOTE(review): fetched but not otherwise used in this kernel.
    int64_t min_row_size_to_use_multithread =
        ctx.Attr<int64_t>("min_row_size_to_use_multithread");
    bool lazy_mode = ctx.Attr<bool>("lazy_mode");
    MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    auto* mom1 = ctx.Input<LoDTensor>("Moment1");
    auto* mom2 = ctx.Input<LoDTensor>("Moment2");
    auto* lr = ctx.Input<LoDTensor>("LearningRate");

    auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
    auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");

    auto* param_out = ctx.Output<LoDTensor>("ParamOut");
    auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
    auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
    auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
    auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");

    // beta1 / beta2 come from the attributes unless the corresponding
    // single-element tensor input is present, which then takes precedence.
    MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
      auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
      PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta1Tensor) size must be 1, but get %d",
                            beta1_tensor->numel()));
      beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
    }
    MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
    if (ctx.HasInput("Beta2Tensor")) {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
                        platform::errors::InvalidArgument(
                            "Input(Beta2Tensor) size must be 1, but get %d",
                            beta2_tensor->numel()));
      beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
    }
    VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
            << "beta2_pow.numel() : " << beta2_pow->numel();
    VLOG(3) << "param.numel(): " << param->numel();

    PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta1 pow output size should be 1, but received "
                          "value is:%d.",
                          beta1_pow_out->numel()));

    PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
                      platform::errors::InvalidArgument(
                          "beta2 pow output size should be 1, but received "
                          "value is:%d.",
                          beta2_pow_out->numel()));

    // With multi_precision both MasterParam and MasterParamOut are required;
    // otherwise the kernels receive null master pointers.
    const bool multi_precision = ctx.Attr<bool>("multi_precision");
    const LoDTensor* master_param = nullptr;
    LoDTensor* master_param_out = nullptr;
    if (multi_precision) {
      bool has_master =
          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
      PADDLE_ENFORCE_EQ(has_master, true,
                        platform::errors::InvalidArgument(
                            "The Input(MasterParam) and Output(MasterParamOut) "
                            "should not be null when "
                            "the attr `multi_precision` is true"));
      master_param = ctx.Input<LoDTensor>("MasterParam");
      master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
    }
    const MPDType* master_in_data =
        multi_precision ? master_param->data<MPDType>() : nullptr;
    MPDType* master_out_data =
        multi_precision
            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
            : nullptr;

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto* grad = ctx.Input<LoDTensor>("Grad");

      // update param and moment
      // One thread per parameter element, 512 threads per block.
      // NOTE(review): numel() is narrowed to int here and in the kernels'
      // `ndim` / `id`; confirm parameters never exceed INT_MAX elements.
      int threads = 512;
      int blocks = (param->numel() + threads - 1) / threads;

      if (beta1_pow->place() == platform::CPUPlace() &&
          beta2_pow->place() == platform::CPUPlace()) {
        // Compute with betapow in REG
        // Both power terms are on CPUPlace: read them on the host and pass
        // them to the kernel by value.
        AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, param->numel());
        // Cpu update
        beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta1 * beta1_pow->data<MPDType>()[0];
        beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta2 * beta2_pow->data<MPDType>()[0];
      } else {
        // The kernel dereferences the beta-pow pointers itself.
        AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, param->numel());
        // Update with gpu
        // Launched after the Adam kernel on the same stream.
        UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
            beta1, beta2, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(),
            beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
            beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
      }
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
      if (grad->rows().size() == 0) {
        // No rows in the gradient: return without launching any update.
        VLOG(3) << "grad row size is 0!!";
        return;
      }

      // The kernels binary-search the row list, so it has to be strictly
      // increasing; check that on a host copy of the rows.
      std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
      bool is_strict_sorted = true;
      for (size_t i = 1; i < cpu_rows.size(); ++i) {
        if (cpu_rows[i - 1] >= cpu_rows[i]) {
          is_strict_sorted = false;
          break;
        }
      }

      framework::SelectedRows tmp_grad_merge;
      const framework::SelectedRows* grad_merge_ptr;
      if (is_strict_sorted) {
        grad_merge_ptr = grad;
      } else {
        // merge duplicated rows if any.
        // The rows of grad_merge have been sorted inside MergeAdd functor
        scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
        merge_func(ctx.template device_context<platform::CUDADeviceContext>(),
                   *grad, &tmp_grad_merge, true);
        grad_merge_ptr = &tmp_grad_merge;
      }
      auto& grad_merge = *grad_merge_ptr;
      auto& grad_tensor = grad_merge.value();
      const T* grad_data = grad_tensor.template data<T>();
      const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
      // Number of gradient values stored per listed row.
      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();

      if (beta1_pow->place() == platform::CPUPlace() &&
          beta2_pow->place() == platform::CPUPlace()) {
        // One thread per parameter element, 512 threads per block.
        // NOTE(review): numel() is narrowed to int; confirm parameters never
        // exceed INT_MAX elements.
        int threads = 512;
        int ndim = param->numel();
        int blocks = (ndim + threads - 1) / threads;

        SparseAdamCUDAKernelREG<
            T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
            beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
            *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode, ndim);
        // Update with cpu
        beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta1 * beta1_pow->data<MPDType>()[0];
        beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
            beta2 * beta2_pow->data<MPDType>()[0];
      } else {
        SparseAdamFunctor<T, GPUAdam, MPDType> functor(
            beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
            mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
            mom2->data<MPDType>(),
            mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
            lr->data<MPDType>(), grad_data, param->data<T>(),
            param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
            master_out_data, rows, row_numel, grad_merge.rows().size(),
            lazy_mode);

        // FIXME(minqiyang): remove BinarySearch in GPU later
        platform::ForRange<platform::CUDADeviceContext> for_range(
            static_cast<const platform::CUDADeviceContext&>(
                ctx.device_context()),
            param->numel());

        for_range(functor);
        // update beta1 and beta2
        UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
            beta1, beta2, beta1_pow->data<MPDType>(),
            beta2_pow->data<MPDType>(),
            beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
            beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by adam_op"));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA kernel of the `adam` op for each supported element type.
REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel<float>,
                        ops::AdamOpCUDAKernel<double>,
                        ops::AdamOpCUDAKernel<plat::float16>);