diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index f5f5f449ebff170154f24c93162d10d753894267..e9c40998483d4e157a38fa56e74f1a8c33158df2 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -516,6 +516,12 @@ paddle.fluid.optimizer.DGCMomentumOptimizer.apply_optimize (ArgSpec(args=['self' paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.optimizer.LambOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) +paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) +paddle.fluid.optimizer.LambOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.optimizer.LambOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.backward.append_backward (ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '08a5dd9f6f376ff3d55e0b1d92115cbd')) paddle.fluid.regularizer.L1DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.regularizer.L2DecayRegularizer.__init__ (ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index 54e0f5146dab3e19713d19e15c6c81868179b319..dd347aa0afebe5c75e7f3b574083783b4454fd20 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -18,67 +18,64 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -class AdamOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void 
InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(Param) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(Grad) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Moment1"), - "Input(Moment1) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Moment2"), - "Input(Moment2) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), - "Input(Beta1Pow) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"), - "Input(Beta2Pow) of AdamOp should not be null."); - - PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), - "Output(ParamOut) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"), - "Output(Moment1Out) of AdamOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), - "Output(Moment2Out) of AdamOp should not be null."); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, - "Learning rate should have 1 dimension"); - auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); - PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, - "Beta1 power accumulator should have 1 dimension"); - auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); - PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1, - "Beta2 power accumulator should have 1 dimension"); - - auto param_dims = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Grad"), - "Param and Grad input of AdamOp should have same dimension"); - } - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Moment1"), - "Param and Moment1 input of AdamOp should have same dimension"); - PADDLE_ENFORCE_EQ( - param_dims, ctx->GetInputDim("Moment2"), - "Param and Moment2 input of AdamOp should have same dimension"); - ctx->SetOutputDim("ParamOut", param_dims); - ctx->SetOutputDim("Moment1Out", param_dims); - ctx->SetOutputDim("Moment2Out", param_dims); - } - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("Param")->type(); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); +void AdamOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment1"), + "Input(Moment1) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment2"), + "Input(Moment2) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"), + "Input(Beta1Pow) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"), + "Input(Beta2Pow) of AdamOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"), + "Output(Moment1Out) of AdamOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"), + "Output(Moment2Out) of AdamOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + 
"Learning rate should have 1 dimension"); + auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1, + "Beta1 power accumulator should have 1 dimension"); + auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); + PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1, + "Beta2 power accumulator should have 1 dimension"); + + auto param_dims = ctx->GetInputDim("Param"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of AdamOp should have same dimension"); } -}; + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment1"), + "Param and Moment1 input of AdamOp should have same dimension"); + PADDLE_ENFORCE_EQ( + param_dims, ctx->GetInputDim("Moment2"), + "Param and Moment2 input of AdamOp should have same dimension"); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("Moment1Out", param_dims); + ctx->SetOutputDim("Moment2Out", param_dims); +} + +framework::OpKernelType AdamOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + auto input_data_type = ctx.Input("Param")->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); +} class AdamOpMaker : public framework::OpProtoAndCheckerMaker { public: diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 6262ef0c2d3802bca574ba1312e7cf4a720403ef..1cc34f11d09e9ec1868249f20fcc1b189efb0589 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -29,6 +29,15 @@ namespace operators { namespace scatter = paddle::operators::math::scatter; +class AdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + struct GPUAdam; struct CPUAdam; diff --git a/paddle/fluid/operators/optimizers/lamb_op.cc b/paddle/fluid/operators/optimizers/lamb_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf407b01d31065b5042ebf5f787d9ddd41e7bced --- /dev/null +++ b/paddle/fluid/operators/optimizers/lamb_op.cc @@ -0,0 +1,95 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/optimizers/lamb_op.h" +#include "paddle/fluid/operators/optimizers/adam_op.h" + +namespace paddle { +namespace operators { + +class LambOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", + "(LoDTensor, default LoDTensor) " + "Input parameter that has to be updated."); + AddInput("Grad", + "(LoDTensor, default LoDTensor) " + "Input gradient of the parameter."); + AddInput("LearningRate", "(Tensor) Learning rate."); + AddInput("Moment1", "(Tensor) Input first moment."); + AddInput("Moment2", "(Tensor) Input second moment."); + AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator."); + AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator."); + + AddOutput("ParamOut", "(Tensor) Output parameter."); + AddOutput("Moment1Out", "(Tensor) Output first moment."); + AddOutput("Moment2Out", "(Tensor) Output second moment."); + AddAttr("weight_decay", "(float) Weight decay rate."); + AddAttr("beta1", + "(float, default 0.9) The exponential decay rate for the " + "1st moment estimates.") + .SetDefault(0.9); + AddAttr("beta2", + "(float, default 0.999) The exponential decay rate for the " + "2nd moment estimates.") + .SetDefault(0.999); + AddAttr("epsilon", + "(float, default 1.0e-6) " + "Constant for numerical stability.") + .SetDefault(1.0e-6f); + + AddComment(R"DOC( +LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer. + +LAMB Optimizer is designed to scale up the batch size of training without losing +accuracy, which supports adaptive element-wise updating and accurate layer-wise +correction. For more information, please refer to https://arxiv.org/abs/1904.00962. + +The updating of parameters follows: + +$$ +m_t^l &= \beta_1 m_{t - 1}^l + (1 - \beta_1)g_t^l \\ + +v_t^l &= \beta_2 v_{t - 1}^l + (1 - \beta_2)g_t^l \odot g_t^l \\ + +\widehat{m}_t^l &= m_t^l/(1 - \beta_1^t) \\ + +\widehat{v}_t^l &= v_t^l/(1 - \beta_2^t) \\ + +r_1 &= \left \| w_{t-1}^l \right \|_2 \\ + +r_2 &= \left \| \frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l \right \|_2 \\ + +r &= r_1 / r_2 \\ + +\eta^l &= r \times \eta \\ + +w_t^l &= w_{t-1}^l -\eta ^l \times (\frac{\widehat{m}_t^l}{\sqrt{\widehat{v}_t^l+\epsilon}} + \lambda w_{t-1}^l) +$$ + +where $m$ is the 1st moment, and $v$ the 2nd moment, $\eta$ the +learning rate, $\lambda$ the weight decay rate. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::AdamOp, ops::LambOpMaker); +REGISTER_OP_CPU_KERNEL( + lamb, ops::LambOpKernel, + ops::LambOpKernel); diff --git a/paddle/fluid/operators/optimizers/lamb_op.cu b/paddle/fluid/operators/optimizers/lamb_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..9ffb62926a4fffd95ca014947282a7a32e92e4b8 --- /dev/null +++ b/paddle/fluid/operators/optimizers/lamb_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/optimizers/lamb_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + lamb, ops::LambOpKernel, + ops::LambOpKernel); diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4fbe486cb1d68f9e569db5774e917e6b5ec9f8bb --- /dev/null +++ b/paddle/fluid/operators/optimizers/lamb_op.h @@ -0,0 +1,314 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include // for sqrt in CPU and CUDA +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +namespace scatter = paddle::operators::math::scatter; + +template +struct LambMomentUpdateFunctor { + T weight_decay_; + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* grad_; + const T* param_; + T* trust_ratio_div_; + + LambMomentUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon, + const T* beta1_pow, const T* beta2_pow, const T* mom1, + T* mom1_out, const T* mom2, T* mom2_out, + const T* grad, const T* param, T* trust_ratio_div) + : weight_decay_(weight_decay), + beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + grad_(grad), + param_(param), + trust_ratio_div_(trust_ratio_div) {} + + inline HOSTDEVICE void operator()(size_t i) const { + T g = grad_[i]; + T mom1 = moment1_[i]; + T mom2 = moment2_[i]; + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + T p = param_[i]; + + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + + T mom1_h = mom1 / (1 - beta1_pow); + T mom2_h = mom2 / (1 - beta2_pow); + + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p; + } +}; + +template +struct SparseLambMomentUpdateFunctor { + T weight_decay_; + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* grad_; + const T* param_; + T* trust_ratio_div_; + + const int64_t* rows_; + int64_t row_numel_; + int64_t row_count_; + + SparseLambMomentUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon, + const T* beta1_pow, const T* beta2_pow, + const T* mom1, T* mom1_out, const T* mom2, + T* mom2_out, const T* grad, const T* param, + T* trust_ratio_div, const int64_t* 
rows, + int64_t row_numel, int64_t row_count) + : weight_decay_(weight_decay), + beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + grad_(grad), + param_(param), + trust_ratio_div_(trust_ratio_div), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count) {} + + inline HOSTDEVICE void update(size_t i, T g) const { + // The following code is same as dense + T mom1 = moment1_[i]; + T mom2 = moment2_[i]; + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + T p = param_[i]; + + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + + T mom1_h = mom1 / (1 - beta1_pow); + T mom2_h = mom2 / (1 - beta2_pow); + + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + trust_ratio_div_[i] = mom1_h / sqrt(mom2_h + epsilon_) + weight_decay_ * p; + } + + inline HOSTDEVICE void operator()(size_t i) const { + auto row_idx = + math::BinarySearch(rows_, row_count_, i / row_numel_); + T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + update(i, g); + } +}; + +template +struct LambParamUpateFunctor { + const T* lr_; + const T* param_; + const T* param_norm_; + const T* trust_ratio_div_; + const T* trust_ratio_div_norm_; + T* param_out_; + + LambParamUpateFunctor(const T* lr, const T* param, const T* param_norm, + const T* trust_ratio_div, const T* trust_ratio_div_norm, + T* param_out) + : lr_(lr), + param_(param), + param_norm_(param_norm), + trust_ratio_div_(trust_ratio_div), + trust_ratio_div_norm_(trust_ratio_div_norm), + param_out_(param_out) {} + + inline HOSTDEVICE void operator()(size_t i) const { + T lr = *lr_; + T p_norm = *param_norm_; + T tr_div_norm = *trust_ratio_div_norm_; + + lr *= p_norm / tr_div_norm; + param_out_[i] = param_[i] - lr * trust_ratio_div_[i]; + } +}; + +template +class LambOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE(param_var->IsType(), + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.Inputs("Param").front(), + framework::ToTypeName(param_var->Type())); + + using paddle::framework::LoDTensor; + using paddle::operators::detail::Ref; + + T weight_decay = static_cast(ctx.Attr("weight_decay")); + T beta1 = static_cast(ctx.Attr("beta1")); + T beta2 = static_cast(ctx.Attr("beta2")); + T epsilon = static_cast(ctx.Attr("epsilon")); + auto& param = Ref(ctx.Input("Param"), "Must set Param."); + auto* grad_var = ctx.InputVar("Grad"); + auto& mom1 = Ref(ctx.Input("Moment1"), "Must set Moment1."); + auto& mom2 = Ref(ctx.Input("Moment2"), "Must set Moment2."); + auto& lr = + Ref(ctx.Input("LearningRate"), "Must set LearningRate."); + + auto& beta1_pow = + Ref(ctx.Input("Beta1Pow"), "Must set Beta1Pow."); + auto& beta2_pow = + Ref(ctx.Input("Beta2Pow"), "Must set Beta2Pow."); + + auto& param_out = + Ref(ctx.Output("ParamOut"), "Must set ParamOut."); + auto& mom1_out = + Ref(ctx.Output("Moment1Out"), "Must set Moment1Out."); + auto& mom2_out = + Ref(ctx.Output("Moment2Out"), "Must set Moment1Out."); + + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, param.numel()); + framework::Tensor trust_ratio_div = + ctx.AllocateTmpTensor(param.dims(), dev_ctx); + + // Update moments + if (grad_var->IsType()) { + auto& grad = Ref(ctx.Input("Grad"), "Must set Grad."); + + 
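+      // Dense path: LambMomentUpdateFunctor updates each element i as
+      //   mom1 = beta1 * mom1 + (1 - beta1) * g
+      //   mom2 = beta2 * mom2 + (1 - beta2) * g * g
+      // and stores the bias-corrected mom1_hat / sqrt(mom2_hat + epsilon)
+      // + weight_decay * p into trust_ratio_div, which the layer-wise
+      // parameter update further below rescales by ||param|| / ||trust_ratio_div||.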
LambMomentUpdateFunctor moment_update_functor( + weight_decay, beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + grad.template data(), param.template data(), + trust_ratio_div.template data()); + for_range(moment_update_functor); + } else if (grad_var->IsType()) { + auto& grad = + Ref(ctx.Input("Grad"), "Must set Grad."); + if (grad.rows().size() == 0) { + VLOG(3) << "grad row size is 0!!"; + return; + } + + std::vector cpu_rows(grad.rows().begin(), grad.rows().end()); + bool is_strict_sorted = true; + for (size_t i = 1; i < cpu_rows.size(); ++i) { + if (cpu_rows[i - 1] >= cpu_rows[i]) { + is_strict_sorted = false; + break; + } + } + + framework::SelectedRows tmp_grad_merge; + const framework::SelectedRows* grad_merge_ptr; + if (is_strict_sorted) { + grad_merge_ptr = &grad; + } else { + // merge duplicated rows if any. + // The rows of grad_merge have been sorted inside MergeAdd functor + scatter::MergeAdd merge_func; + merge_func(dev_ctx, grad, &tmp_grad_merge, true); + grad_merge_ptr = &tmp_grad_merge; + } + + auto& grad_merge = *grad_merge_ptr; + auto& grad_tensor = grad_merge.value(); + const T* grad_data = grad_tensor.template data(); + const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace()); + auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); + + SparseLambMomentUpdateFunctor moment_update_functor( + weight_decay, beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), grad_data, + param.template data(), trust_ratio_div.template data(), rows, + row_numel, grad_merge.rows().size()); + for_range(moment_update_functor); + } else { + PADDLE_THROW("Variable type not supported by lamb_op."); + } + + // Update parameter + framework::Tensor p_norm_t = + ctx.AllocateTmpTensor({1}, dev_ctx); + framework::Tensor trust_ratio_div_norm_t = + ctx.AllocateTmpTensor({1}, dev_ctx); + auto p_norm = framework::EigenScalar::From(p_norm_t); + auto trust_ratio_div_norm = + framework::EigenScalar::From(trust_ratio_div_norm_t); + + auto p = framework::EigenVector::Flatten(param); + auto t = framework::EigenVector::Flatten(trust_ratio_div); + + auto* place = dev_ctx.eigen_device(); + p_norm.device(*place) = p.square().sum().sqrt(); + trust_ratio_div_norm.device(*place) = t.square().sum().sqrt(); + + LambParamUpateFunctor param_update_functor( + lr.template data(), param.template data(), + p_norm_t.template data(), trust_ratio_div.template data(), + trust_ratio_div_norm_t.template data(), + param_out.template mutable_data(ctx.GetPlace())); + for_range(param_update_functor); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py index 29acb589e171c85bccd52e8b005ea6a3b90397a7..a2e700803dcf3a2da5b7f1e15b68fb8b274a939a 100644 --- a/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py +++ b/python/paddle/fluid/contrib/tests/test_image_classification_fp16.py @@ -132,7 +132,7 @@ def train(net_type, use_cuda, save_dirname, is_local): # Test program test_program = train_program.clone(for_test=True) - optimizer = fluid.optimizer.Adam(learning_rate=0.001) + optimizer = 
fluid.optimizer.Lamb(learning_rate=0.001) mp_optimizer = fluid.contrib.mixed_precision.decorate( optimizer=optimizer, diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 4f21ace202120031db9f331713d56f7bdacc0ba6..f494ab92664b3df506f31418bb81e8370111087c 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -42,7 +42,7 @@ __all__ = [ 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', - 'LarsMomentumOptimizer', 'DGCMomentumOptimizer' + 'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer' ] @@ -1851,6 +1851,133 @@ class FtrlOptimizer(Optimizer): return ftrl_op +class LambOptimizer(AdamOptimizer): + """ + LAMB (Layer-wise Adaptive Moments optimizer for Batching training) Optimizer. + + LAMB Optimizer is designed to scale up the batch size of training without losing + accuracy, which supports adaptive element-wise updating and accurate layer-wise + correction. For more information, please refer to `Reducing BERT Pre-Training + Time from 3 Days to 76 Minutes `_ . + + The updating of parameters follows: + + .. math:: + + m_t^l & = \\beta_1 m_{t - 1}^l + (1 - \\beta_1)g_t^l + + v_t^l & = \\beta_2 v_{t - 1}^l + (1 - \\beta_2)g_t^l \odot g_t^l + + \\widehat{m}_t^l & = m_t^l/(1 - \\beta_1^t) + + \\widehat{v}_t^l & = v_t^l/(1 - \\beta_2^t) + + r_1 & = \\left \| w_{t-1}^l \\right \|_2 + + r_2 & = \\left \| \\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l \\right \|_2 + + r & = r_1 / r_2 + + \\eta^l & = r \\times \\eta + + w_t^l & = w_{t-1}^l -\\eta ^l \\times (\\frac{\\widehat{m}_t^l}{\\sqrt{\\widehat{v}_t^l+\\epsilon}} + \\lambda w_{t-1}^l) + + + where :math:`m` is the 1st moment, and :math:`v` the 2nd moment, :math:`\\eta` the + learning rate, :math:`\\lambda` the LAMB weight decay rate. + + Args: + learning_rate (float|Variable): the learning rate used to update parameters. \ + Can be a float value or a Variable with one \ + float value as data element. + lamb_weight_decay (float): The LAMB weight decay rate. + beta1 (float): The exponential decay rate for the 1st moment estimates. + beta2 (float): The exponential decay rate for the 2nd moment estimates. + epsilon (float): A small float value for numerical stability. + regularization: A Regularizer, such as + fluid.regularizer.L1DecayRegularizer. + name (str|None): An optional name prefix. + + Examples: + .. 
code-block:: python + + import paddle.fluid as fluid + + data = fluid.layers.data(name='x', shape=[5], dtype='float32') + hidden = fluid.layers.fc(input=data, size=10) + cost = fluid.layers.mean(hidden) + + optimizer = fluid.optimizer.Lamb(learning_rate=0.002) + optimizer.minimize(cost) + """ + _moment1_acc_str = "moment1" + _moment2_acc_str = "moment2" + _beta1_pow_acc_str = "beta1_pow_acc" + _beta2_pow_acc_str = "beta2_pow_acc" + + def __init__(self, + learning_rate=0.001, + lamb_weight_decay=0.01, + beta1=0.9, + beta2=0.999, + epsilon=1e-6, + regularization=None, + name=None): + assert learning_rate is not None + assert lamb_weight_decay is not None + assert beta1 is not None + assert beta2 is not None + assert epsilon is not None + super(LambOptimizer, self).__init__( + learning_rate=learning_rate, + regularization=regularization, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + name=name) + self.type = "lamb" + self._weight_decay = lamb_weight_decay + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + moment1 = self._get_accumulator(self._moment1_acc_str, + param_and_grad[0]) + moment2 = self._get_accumulator(self._moment2_acc_str, + param_and_grad[0]) + beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, + param_and_grad[0]) + beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, + param_and_grad[0]) + + # create the lamb optimize op + lamb_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "LearningRate": self._create_param_lr(param_and_grad), + "Moment1": moment1, + "Moment2": moment2, + "Beta1Pow": beta1_pow_acc, + "Beta2Pow": beta2_pow_acc + }, + outputs={ + "ParamOut": param_and_grad[0], + "Moment1Out": moment1, + "Moment2Out": moment2 + }, + attrs={ + "beta1": self._beta1, + "beta2": self._beta2, + "epsilon": self._epsilon, + "weight_decay": self._weight_decay + }, + stop_gradient=True) + + return lamb_op + + # We short the class name, since users will use the optimizer with the package # name. The sample code: # @@ -1869,6 +1996,7 @@ Adadelta = AdadeltaOptimizer RMSProp = RMSPropOptimizer Ftrl = FtrlOptimizer LarsMomentum = LarsMomentumOptimizer +Lamb = LambOptimizer class ModelAverage(Optimizer): diff --git a/python/paddle/fluid/tests/unittests/test_lamb_op.py b/python/paddle/fluid/tests/unittests/test_lamb_op.py new file mode 100644 index 0000000000000000000000000000000000000000..cac9e2643d69ec70039382f35937dc04c0e76f4b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lamb_op.py @@ -0,0 +1,302 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import core +from paddle.fluid.op import Operator + + +class TestLambOp1(OpTest): + def set_attrs(self): + self.attrs = { + 'epsilon': 1e-4, + 'beta1': 0.78, + 'beta2': 0.836, + 'weight_decay': 0.01 + } + + def setUp(self): + '''Test Lamb Op with supplied attributes + ''' + self.op_type = "lamb" + param = np.random.uniform(-1, 1, (102, 105)).astype("float32") + grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + moment2 = np.random.random((102, 105)).astype("float32") + + learning_rate = 0.001 + self.set_attrs() + beta1_pow = self.attrs['beta1']**10 + beta2_pow = self.attrs['beta2']**10 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment1': moment1, + 'Moment2': moment2, + 'LearningRate': np.array([learning_rate]).astype("float32"), + 'Beta1Pow': np.array([beta1_pow]).astype("float32"), + 'Beta2Pow': np.array([beta2_pow]).astype("float32") + } + + + param_out, moment1_out, \ + moment2_out = lamb_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out + } + + def test_check_output(self): + self.check_output() + + +class TestLambOp2(TestLambOp1): + def set_attrs(self): + self.attrs = { + 'epsilon': 1e-8, + 'beta1': 0.9, + 'beta2': 0.999, + 'weight_decay': 0.01 + } + + +class TestLambOpMultipleSteps(TestLambOp1): + def set_attrs(self): + self.attrs = { + 'epsilon': 1e-8, + 'beta1': 0.9, + 'beta2': 0.999, + 'weight_decay': 0.01 + } + self.num_steps = 10 + + def test_check_output(self): + for _ in range(self.num_steps): + param_out, moment1_out, \ + moment2_out = lamb_step(self.inputs, self.attrs) + + self.outputs = { + 'Moment1Out': moment1_out, + 'Moment2Out': moment2_out, + 'ParamOut': param_out + } + + # Verify output for this step + self.check_output() + + # Output of this step becomes input for next step + self.inputs['Param'] = param_out + self.inputs['Moment1'] = moment1_out + self.inputs['Moment2'] = moment2_out + + # Update powers of Beta1 and Beta2 for next time step + self.inputs['Beta1Pow'] *= self.attrs['beta1'] + self.inputs['Beta2Pow'] *= self.attrs['beta1'] + + # Randomize gradient for next step + self.inputs['Grad'] = np.random.uniform( + -1, 1, (102, 105)).astype("float32") + + +def lamb_step(inputs, attributes): + ''' + Simulate one step of the lamb optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + weight_decay = attributes['weight_decay'] + + moment1_out = beta1 * moment1 + (1 - beta1) * grad + moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) + + mom1_tmp = moment1_out / (1 - beta1_pow) + mom2_tmp = moment2_out / (1 - beta2_pow) + + r_1 = np.linalg.norm(param) + r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) + weight_decay * + param) + lr_t = lr * r_1 / r_2 + + param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) + + weight_decay * param) + return param_out, moment1_out, moment2_out + + +def 
lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad): + ''' + Simulate one step of the lamb optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + # grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + weight_decay = attributes['weight_decay'] + + moment1_out = np.zeros(shape=[height, row_numel]) + moment2_out = np.zeros(shape=[height, row_numel]) + param_out = np.zeros(shape=[height, row_numel]) + + def update_mom(row_id, update_value): + moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 + ) * update_value + moment2_out[row_id] = beta2 * moment2[row_id] + ( + 1 - beta2) * np.square(update_value) + + moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 + ) * update_value + moment2_out[row_id] = beta2 * moment2[row_id] + ( + 1 - beta2) * np.square(update_value) + + def update_param(): + mom1_tmp = moment1_out / (1 - beta1_pow) + mom2_tmp = moment2_out / (1 - beta2_pow) + + r_1 = np.linalg.norm(param) + r_2 = np.linalg.norm(mom1_tmp / np.sqrt(mom2_tmp + epsilon) + + weight_decay * param) + lr_t = lr * r_1 / r_2 + + param_out = param - lr_t * (mom1_tmp / np.sqrt(mom2_tmp + epsilon) + + weight_decay * param) + + for row_id in range(param_out.shape[0]): + update_value = np.zeros(np_grad[0].shape).astype("float32") + if row_id in rows: + update_value = np_grad[rows.index(row_id)] + update_mom(row_id, update_value) + + update_param() + + return param_out, moment1_out, moment2_out + + +class TestSparseLambOp(unittest.TestCase): + def setup(self, scope, place): + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + + height = 10 + rows = [0, 4, 7] + self.rows = rows + row_numel = 12 + self.row_numel = row_numel + self.dense_inputs = { + "Param": np.full((height, row_numel), 5.0).astype("float32"), + "Moment1": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + 'Beta1Pow': np.array([beta1**10]).astype("float32"), + 'Beta2Pow': np.array([beta2**10]).astype("float32"), + "LearningRate": np.full((1), 2.0).astype("float32") + } + self.init_output = np.full((height, row_numel), 0.0).astype("float32") + self.attrs = { + 'epsilon': epsilon, + 'beta1': beta1, + 'beta2': beta2, + 'weight_decay': 0.05 + } + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + self.sparse_inputs = ["Grad"] + + param_out, mom1, mom2 = lamb_step_sparse( + self.dense_inputs, self.attrs, height, rows, row_numel, np_array) + self.outputs = { + "ParamOut": param_out, + "Moment1Out": mom1, + "Moment2Out": mom2 + } + + def check_with_place(self, place): + scope = core.Scope() + self.setup(scope, place) + + op_args = dict() + for key, np_array in self.dense_inputs.items(): + var = scope.var(key).get_tensor() + var.set(np_array, place) + op_args[key] = key + for s in self.sparse_inputs: + op_args[s] = s + for s in self.outputs: + var = scope.var(s).get_tensor() + 
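+            # pre-fill output tensors with zeros before the op writes into them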
+            var.set(self.init_output, place)
+            op_args[s] = s
+        for k in self.attrs:
+            op_args[k] = self.attrs[k]
+
+        # create and run lamb operator
+        lamb_op = Operator("lamb", **op_args)
+        lamb_op.run(scope, place)
+
+        for key, np_array in self.outputs.items():
+            out_var = scope.var(key).get_tensor()
+            actual = np.array(out_var)
+            actual = actual.reshape([actual.size])
+            np_array = np_array.reshape([np_array.size])
+
+            for i in range(np_array.size):
+                self.assertLess(abs(actual[i] - np_array[i]), 0.00001)
+
+    def test_sparse_lamb(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(core.CUDAPlace(0))
+        for place in places:
+            self.check_with_place(place)
+
+
+if __name__ == "__main__":
+    unittest.main()
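
Note: as a quick sanity check of the update rule outside of Paddle, the following standalone NumPy sketch mirrors the dense `lamb_step` reference above. The function name `lamb_dense_step` and the toy shapes/hyper-parameters are illustrative only, not part of the PR.

import numpy as np


def lamb_dense_step(param, grad, m, v, lr, beta1_pow, beta2_pow,
                    beta1=0.9, beta2=0.999, epsilon=1e-6, weight_decay=0.01):
    # First and second moment estimates, as in Adam.
    m_out = beta1 * m + (1 - beta1) * grad
    v_out = beta2 * v + (1 - beta2) * np.square(grad)

    # Bias correction.
    m_hat = m_out / (1 - beta1_pow)
    v_hat = v_out / (1 - beta2_pow)

    # Element-wise Adam-style step plus decoupled weight decay.
    update = m_hat / np.sqrt(v_hat + epsilon) + weight_decay * param

    # Layer-wise trust ratio r1 / r2 rescales the learning rate.
    r1 = np.linalg.norm(param)
    r2 = np.linalg.norm(update)
    lr_t = lr * r1 / r2

    return param - lr_t * update, m_out, v_out


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    w = rng.uniform(-1, 1, (4, 3)).astype("float32")
    g = rng.uniform(-1, 1, (4, 3)).astype("float32")
    m = np.zeros_like(w)
    v = np.zeros_like(w)
    w, m, v = lamb_dense_step(w, g, m, v, lr=0.001,
                              beta1_pow=0.9, beta2_pow=0.999)
    print(w)

Feeding this helper the same inputs and attributes as TestLambOp1 should reproduce the ParamOut/MomentOut values computed by lamb_step, since it applies the identical formulas.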