From 91cf918fd0a0a6fe0a3e7f60b8478d46ca931fa8 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Tue, 14 Sep 2021 21:41:17 +0800 Subject: [PATCH] add layerwise learning rate for adamw (#35569) * add layerwise learning rate for adamw * fix format * add unitest * add NotImplementedError * add gpu unitest * update gpuplace --- paddle/fluid/operators/optimizers/adam_op.cc | 4 + paddle/fluid/operators/optimizers/adamw_op.cu | 68 +++++++-------- paddle/fluid/operators/optimizers/adamw_op.h | 25 +++--- .../fluid/tests/unittests/test_adamw_op.py | 87 +++++++++++++++++++ python/paddle/optimizer/adamw.py | 19 +++- 5 files changed, 158 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index d4355c89f31..0c2fedad739 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -236,6 +236,10 @@ class AdamWOpMaker : public AdamOpMaker { public: void Make() { AdamOpMaker::Make(); + AddAttr("lr_ratio", + "(float, default 1.0) " + "layerwise learning rate decay") + .SetDefault(1.0f); AddAttr("coeff", "(float, default 0.01) " "coeff of the weight decay") diff --git a/paddle/fluid/operators/optimizers/adamw_op.cu b/paddle/fluid/operators/optimizers/adamw_op.cu index af2bb93e06d..49b7fe771be 100644 --- a/paddle/fluid/operators/optimizers/adamw_op.cu +++ b/paddle/fluid/operators/optimizers/adamw_op.cu @@ -20,17 +20,17 @@ namespace operators { template __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff, - MT beta1_pow_, MT beta2_pow_, const MT* moment1, - MT* moment1_out, const MT* moment2, - MT* moment2_out, const MT* lr_, const T* grad, - const T* param, T* param_out, - const MT* master_param, MT* master_param_out, - int ndim) { - MT lr = *lr_; + MT lr_ratio, MT beta1_pow_, MT beta2_pow_, + const MT* moment1, MT* moment1_out, + const MT* moment2, MT* moment2_out, + const MT* lr_, const T* grad, const T* param, + T* param_out, const MT* master_param, + MT* master_param_out, int ndim) { + MT lr = *lr_ * lr_ratio; + MT lr_orig = lr; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; - MT wd = static_cast(1.0) - coeff * lr; lr *= sqrt(static_cast(1.0) - beta2_pow) / (static_cast(1.0) - beta1_pow); @@ -43,9 +43,9 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff, MT mom2 = moment2[id]; mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p = wd * p - - lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); + p -= lr_orig * coeff * p; + p -= lr * (mom1 / + (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; @@ -57,18 +57,16 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff, } template -__global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff, - const MT* beta1_pow_, const MT* beta2_pow_, - const MT* moment1, MT* moment1_out, - const MT* moment2, MT* moment2_out, - const MT* lr_, const T* grad, const T* param, - T* param_out, const MT* master_param, - MT* master_param_out, int ndim) { - MT lr = *lr_; +__global__ void AdamWKernelMEM( + MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow_, + const MT* beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2, + MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out, + const MT* master_param, MT* master_param_out, 
int ndim) { + MT lr = *lr_ * lr_ratio; + MT lr_orig = lr; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; - MT wd = static_cast(1.0) - coeff * lr; lr *= sqrt(static_cast(1.0) - beta2_pow) / (static_cast(1.0) - beta1_pow); @@ -81,9 +79,9 @@ __global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff, MT mom2 = static_cast(moment2[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p = wd * p - - lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); + p -= lr_orig * coeff * p; + p -= lr * (mom1 / + (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; @@ -103,16 +101,16 @@ __global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_, template __global__ void SparseAdamWCUDAKernelREG( - MT beta1, MT beta2, MT epsilon, MT coeff, const MT beta1_pow, + MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT beta1_pow, const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_, T* param_out_, const MT* master_param, MT* master_param_out, const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) { int id = blockIdx.x * blockDim.x + threadIdx.x; - MT lr = *lr_; + MT lr = *lr_ * lr_ratio; + MT lr_orig = lr; - MT wd = static_cast(1.0) - coeff * lr; lr *= sqrt(static_cast(1.0) - beta2_pow) / (static_cast(1.0) - beta1_pow); @@ -130,9 +128,9 @@ __global__ void SparseAdamWCUDAKernelREG( : static_cast(0); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p = wd * p - - lr * (mom1 / (sqrt(mom2) + - epsilon * sqrt(static_cast(1.0) - beta2_pow))); + p -= lr_orig * coeff * p; + p -= lr * (mom1 / (sqrt(mom2) + + epsilon * sqrt(static_cast(1.0) - beta2_pow))); // Write back to global memory mom1_out_[id] = mom1; @@ -165,7 +163,9 @@ class AdamWOpCUDAKernel : public framework::OpKernel { bool lazy_mode = ctx.Attr("lazy_mode"); bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; - float coeff = ctx.Attr("coeff"); + + MPDType coeff = static_cast(ctx.Attr("coeff")); + MPDType lr_ratio = static_cast(ctx.Attr("lr_ratio")); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); @@ -301,7 +301,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel { beta2_pow->place() == platform::CPUPlace()) { // Compute with betapow in REG AdamWKernelREG<<>>( - beta1, beta2, epsilon, coeff, *beta1_pow->data(), + beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data(), *beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), @@ -318,7 +318,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel { } } else { AdamWKernelMEM<<>>( - beta1, beta2, epsilon, coeff, beta1_pow->data(), + beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data(), beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), @@ -377,7 +377,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel { SparseAdamWCUDAKernelREG< T, MPDType><<>>( - beta1, beta2, epsilon, coeff, *beta1_pow->data(), + beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data(), *beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), @@ -395,7 +395,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel { } } else { SparseAdamWFunctor functor( - beta1, beta2, 
epsilon, coeff, beta1_pow->data(), + beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data(), beta2_pow->data(), mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), mom2->data(), diff --git a/paddle/fluid/operators/optimizers/adamw_op.h b/paddle/fluid/operators/optimizers/adamw_op.h index b6dce0a68c6..1904db4f7d6 100644 --- a/paddle/fluid/operators/optimizers/adamw_op.h +++ b/paddle/fluid/operators/optimizers/adamw_op.h @@ -32,12 +32,13 @@ template class AdamWFunctor { private: const T coeff_; + const T lr_ratio_; const T* lr_; T* param_; public: - AdamWFunctor(const T coeff, const T* lr, T* param) - : coeff_(coeff), lr_(lr), param_(param) {} + AdamWFunctor(const T coeff, const T lr_ratio, const T* lr, T* param) + : coeff_(coeff), lr_ratio_(lr_ratio), lr_(lr), param_(param) {} inline HOSTDEVICE void operator()(size_t numel) const { Eigen::Map> param{ @@ -46,7 +47,7 @@ class AdamWFunctor { T lr = *lr_; // Calculation - param = param * (1 - lr * coeff_); + param -= lr * lr_ratio_ * coeff_ * param; } }; @@ -60,6 +61,7 @@ class SparseAdamWFunctor { MT beta2_; MT epsilon_; MT coeff_; + MT lr_ratio_; const MT* beta1_pow_; const MT* beta2_pow_; @@ -80,7 +82,7 @@ class SparseAdamWFunctor { bool lazy_mode_; public: - SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff, + SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow, const MT* beta2_pow, const MT* mom1, MT* mom1_out, const MT* mom2, MT* mom2_out, const MT* lr, const T* grad, const T* param, T* param_out, @@ -91,6 +93,7 @@ class SparseAdamWFunctor { beta2_(beta2), epsilon_(epsilon), coeff_(coeff), + lr_ratio_(lr_ratio), beta1_pow_(beta1_pow), beta2_pow_(beta2_pow), moment1_(mom1), @@ -112,21 +115,21 @@ class SparseAdamWFunctor { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; - MT lr = *lr_; + MT lr = *lr_ * lr_ratio_; + MT lr_orig = lr; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; MT p = master_param_ ? 
master_param_[i] : static_cast(param_[i]); // Calculation - MT wd = static_cast(1.0) - coeff_ * lr; lr *= sqrt(static_cast(1.0) - beta2_pow) / (static_cast(1.0) - beta1_pow); mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; - p = wd * p - - lr * (mom1 / - (sqrt(mom2) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); + p -= lr_orig * coeff_ * p; + p -= lr * (mom1 / (sqrt(mom2) + + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); // Write back to global memory moment1_out_[i] = mom1; @@ -187,6 +190,7 @@ class AdamWOpKernel : public AdamOpKernel { } T coeff = static_cast(ctx.Attr("coeff")); + T lr_ratio = static_cast(ctx.Attr("lr_ratio")); auto* lr = ctx.Input("LearningRate"); LoDTensor* param; @@ -198,7 +202,8 @@ class AdamWOpKernel : public AdamOpKernel { param = const_cast(ctx.Input("Param")); } - AdamWFunctor functor(coeff, lr->data(), param->data()); + AdamWFunctor functor(coeff, lr_ratio, lr->data(), + param->data()); functor(param->numel()); AdamOpKernel::Compute(ctx); diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py index d99e15b2128..2a5dc76c6bb 100644 --- a/python/paddle/fluid/tests/unittests/test_adamw_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py @@ -16,6 +16,7 @@ import unittest import paddle import numpy as np import paddle.fluid as fluid +from functools import partial class TestAdamWOp(unittest.TestCase): @@ -148,5 +149,91 @@ class TestAdamWOpGroupWithLR(TestAdamWOp): adam.clear_gradients() +def simple_lr_setting(param, decay_rate, n_layers): + if "fc_0" in param.name or "linear_1" in param.name: + depth = int(param.name.split("_")[2]) + 1 + elif "fc_1" in param.name or "linear_2" in param.name: + depth = int(param.name.split("_")[2]) + 2 + else: + depth = 0 + + return decay_rate**(n_layers + 2 - depth) + + +class TestAdamWOpLayerwiseLR(TestAdamWOp): + def test_adamw_op_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear1 = paddle.nn.Linear(13, 8) + linear2 = paddle.nn.Linear(8, 5) + + simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2) + + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters=[{ + 'params': linear1.parameters() + }, { + 'params': linear2.parameters(), + }], + apply_decay_param_fun=lambda name: True, + weight_decay=0.01, + lr_ratio=simple_lr_fun) + + for _ in range(2): + a1 = linear1(a) + out = linear2(a1) + out.backward() + adam.step() + adam.clear_gradients() + + def test_adamw_op(self): + paddle.enable_static() + place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \ + else fluid.CPUPlace() + train_prog = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(train_prog, startup): + with fluid.unique_name.guard(): + x = fluid.data(name='x', shape=[None, 10], dtype='float32') + y = fluid.data(name='y', shape=[None, 1], dtype='float32') + + fc1 = fluid.layers.fc(input=x, size=32, act=None) + prediction = fluid.layers.fc(input=fc1, size=1, act=None) + cost = fluid.layers.square_error_cost(input=prediction, label=y) + avg_cost = fluid.layers.mean(cost) + + simple_lr_fun = partial( + simple_lr_setting, decay_rate=0.8, n_layers=2) + + beta1 = fluid.layers.create_global_var( + shape=[1], value=0.85, dtype='float32', persistable=True) + beta2 = fluid.layers.create_global_var( + shape=[1], value=0.95, dtype='float32', persistable=True) + betas = [beta1, beta2] + opt = 
paddle.optimizer.AdamW( + learning_rate=1e-5, + beta1=beta1, + beta2=beta2, + weight_decay=0.01, + epsilon=1e-8, + lr_ratio=simple_lr_fun) + opt.minimize(avg_cost) + + exe = fluid.Executor(place) + exe.run(startup) + for _ in range(2): + inputs = np.random.random(size=[8, 10]).astype('float32') + outputs = np.random.random(size=[8, 1]).astype('float32') + rets = exe.run(train_prog, + feed={"x": inputs, + "y": outputs}, + fetch_list=[avg_cost]) + assert rets[0] is not None + + paddle.disable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 158d0870965..0efc40d3300 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -18,6 +18,7 @@ from ..fluid import core from ..fluid import framework from ..fluid.framework import Variable from ..fluid.dygraph import base as imperative_base +from collections import Callable import paddle _C_ops = core.ops @@ -63,6 +64,10 @@ class AdamW(Adam): epsilon (float, optional): A small float value for numerical stability. The default value is 1e-08. weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01. + lr_ratio (function|None, optional): If it is not None, + the learning rate will be updated with layerwise learning rate ratio. + Otherwise, the learning rate is the original. + Default: None. apply_decay_param_fun (function|None, optional): If it is not None, only tensors that makes apply_decay_param_fun(Tensor.name)==True will be updated with weight decay. It only works when we want to specify tensors. @@ -140,6 +145,7 @@ class AdamW(Adam): epsilon=1e-8, parameters=None, weight_decay=0.01, + lr_ratio=None, apply_decay_param_fun=None, grad_clip=None, lazy_mode=False, @@ -163,6 +169,12 @@ class AdamW(Adam): self._apply_decay_param_fun = apply_decay_param_fun self._coeff = coeff self._lr_to_coeff = dict() + if lr_ratio is not None: + assert isinstance(lr_ratio, Callable) + if core.is_compiled_with_xpu() or core.is_compiled_with_npu(): + raise NotImplementedError( + "'lr_ratio' is unimplemented in XPU and NPU") + self._lr_ratio = lr_ratio super(AdamW, self).__init__( learning_rate=learning_rate, @@ -278,6 +290,8 @@ class AdamW(Adam): # create the adamw optimize op if framework.in_dygraph_mode(): + lr_ratio_ = 1. if self._lr_ratio is None else self._lr_ratio( + param_and_grad[0]) _beta1 = self._beta1 if not isinstance( self._beta1, Variable) else self._beta1.numpy().item(0) @@ -288,7 +302,8 @@ class AdamW(Adam): beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon, 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', - 1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff) + 1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff, + "lr_ratio", lr_ratio_) return None @@ -321,6 +336,8 @@ class AdamW(Adam): "multi_precision": find_master, "with_decay": with_decay, "coeff": self._coeff, + "lr_ratio": 1. + if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) } if isinstance(self._beta1, Variable): -- GitLab
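
Usage note (not part of the patch): the snippet below is a minimal, hypothetical sketch of how the new `lr_ratio` argument introduced by this PR can be driven from user code, modeled on `TestAdamWOpLayerwiseLR` above. The `layerwise_lr_ratio` helper and the parameter-name-to-depth mapping are illustrative assumptions, not part of Paddle.

# Hypothetical sketch: layer-wise learning-rate decay via the new `lr_ratio`
# argument of paddle.optimizer.AdamW (modeled on TestAdamWOpLayerwiseLR).
from functools import partial

import numpy as np
import paddle

linear1 = paddle.nn.Linear(13, 8)
linear2 = paddle.nn.Linear(8, 5)

# Illustrative depth map: the layer closest to the output gets depth 1, the
# layer below it depth 2, so earlier layers receive a smaller LR ratio.
depth_of = {p.name: depth
            for depth, layer in enumerate([linear2, linear1], start=1)
            for p in layer.parameters()}


def layerwise_lr_ratio(param, decay_rate, depth_of):
    # AdamW calls this with each parameter and multiplies the base learning
    # rate by the returned ratio (the "lr_ratio" op attribute in this patch).
    return decay_rate ** depth_of.get(param.name, 0)


opt = paddle.optimizer.AdamW(
    learning_rate=0.01,
    parameters=linear1.parameters() + linear2.parameters(),
    weight_decay=0.01,
    lr_ratio=partial(layerwise_lr_ratio, decay_rate=0.8, depth_of=depth_of))

x = paddle.to_tensor(np.random.rand(4, 13).astype("float32"))
loss = paddle.mean(linear2(paddle.nn.functional.relu(linear1(x))))
loss.backward()
opt.step()
opt.clear_grad()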