From 1ebd7434d545f8c439792468298f1108b631668e Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Sun, 31 Mar 2019 15:00:00 +0800 Subject: [PATCH] Add linear learning warmup method in learning rate scheduler. (#16563) * Add linear learning warmup method This warmup lr can be combinated with other learning rate strategies. For example: decayed_lr = fluid.layers.linear_lr_warmup( fluid.layers.piecewise_decay(boundaries, lr_steps), warmup_steps, start_lr, end_lr) --- paddle/fluid/API.spec | 1 + .../fluid/layers/learning_rate_scheduler.py | 58 ++++++++++++++++++- .../unittests/test_learning_rate_scheduler.py | 47 ++++++++++++++- 3 files changed, 102 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e1d20051b49..54fb8016f5b 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -359,6 +359,7 @@ paddle.fluid.layers.piecewise_decay (ArgSpec(args=['boundaries', 'values'], vara paddle.fluid.layers.noam_decay (ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None), ('document', 'd9a95746353fd574be36dc28d8726c28')) paddle.fluid.layers.append_LARS (ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None), ('document', 'd24fa1e7d62ac8a534fc6a86002f84f8')) paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', '9588c64c26ffaef3c466e404a6af9d9b')) +paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', '2ef3f5ca5cd71ea4217c418e5a7a0565')) paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.StateCell.__init__ (ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.StateCell.compute_state (ArgSpec(args=['self', 'inputs'], varargs=None, keywords=None, defaults=None), ('document', '92973b3f222081a1d17069c683cf4a99')) diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index 378aeb37605..be842622977 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -33,7 +33,7 @@ import math __all__ = [ 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', 'polynomial_decay', 'piecewise_decay', 'noam_decay', 'append_LARS', - 'cosine_decay' + 'cosine_decay', 'linear_lr_warmup' ] @@ -383,3 +383,59 @@ def append_LARS(params_grads, learning_rate, weight_decay): / _balanced_weight(param_norm, grad_norm) # set back param local learning rate param.optimize_attr['learning_rate'] = decayed_lr + + +def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr): + """ + Applies linear learning rate warmup before the normal learning rate + scheduling. + + .. code-block:: python + + if global_step < warmup_steps: + linear_step = end_lr - start_lr + lr = start_lr + linear_step * (global_step / warmup_steps) + + Args: + learning_rate (float | Variable): A float value or Variable. + warmup_steps (int): The warmup steps. + start_lr (float): The start learning of warmup. + end_lr (float): The end learning of warmup. + + Returns: + The decayed learning rate in warmup period. + + Examples: + .. code-block:: python + + boundaries = [100, 200] + lr_steps = [0.1, 0.01, 0.001] + warmup_steps = 50 + start_lr = 1. / 3. + end_lr = 0.1 + decayed_lr = fluid.layers.linear_lr_warmup( + fluid.layers.piecewise_decay(boundaries, lr_steps), + warmup_steps, start_lr, end_lr) + + """ + assert (isinstance(end_lr, float)) + assert (isinstance(start_lr, float)) + linear_step = end_lr - start_lr + with default_main_program()._lr_schedule_guard(): + lr = tensor.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name="learning_rate_warmup") + + global_step = _decay_step_counter() + + with control_flow.Switch() as switch: + with switch.case(global_step < warmup_steps): + decayed_lr = start_lr + linear_step * (global_step / + float(warmup_steps)) + tensor.assign(decayed_lr, lr) + with switch.default(): + tensor.assign(learning_rate, lr) + return lr diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py index 5212d97dfbc..2108c2a9f53 100644 --- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py @@ -120,9 +120,9 @@ class TestLearningRateDecay(unittest.TestCase): self.assertAlmostEqual( python_decayed_lr, lr_val[0], - msg='Failed fn is {0}, Python result is {1}, Fluid result is {2}'. + msg='Failed lr scheduler is {0}, step {1}, Python result is {2}, Fluid result is {3}'. format(python_decay_fn.__name__, - str(python_decayed_lr), str(lr_val[0]))) + str(step), str(python_decayed_lr), str(lr_val[0]))) def test_decay(self): common_kwargs_true = { @@ -164,12 +164,53 @@ class TestLearningRateDecay(unittest.TestCase): ] for py_decay_fn, fluid_decay_fn, kwargs in decay_fns: - print("decay_fn=" + py_decay_fn.__name__ + " kwargs=" + str(kwargs)) + print("class=" + self.__class__.__name__ + "decay_fn=" + + py_decay_fn.__name__ + " kwargs=" + str(kwargs)) main_program = framework.Program() startup_program = framework.Program() with framework.program_guard(main_program, startup_program): self.check_decay(py_decay_fn, fluid_decay_fn, kwargs) +def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr): + linear_step = end_lr - start_lr + decayed_lr = start_lr + linear_step * (global_step / warmup_steps) + return decayed_lr + + +class TestLinearWamrupLearningRateDecay(TestLearningRateDecay): + def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn, + kwargs): + main_prog = fluid.Program() + startup_prog = fluid.Program() + + warmup_steps = 10 + start_lr = 1. / 3. + end_lr = 0.1 + + with fluid.program_guard(main_prog, startup_prog): + decayed_lr = layers.linear_lr_warmup( + fluid_decay_fn(**kwargs), warmup_steps, start_lr, end_lr) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + for step in range(20): + lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr]) + if step < warmup_steps: + python_decayed_lr = linear_lr_warmup( + float(step), warmup_steps, start_lr, end_lr) + else: + python_decayed_lr = python_decay_fn( + global_step=float(step), **kwargs) + self.assertAlmostEqual( + python_decayed_lr, + lr_val[0], + msg='Test {0} Failed, step {1}, Python result is {2}, Fluid result is {3}'. + format(python_decay_fn.__name__, + str(step), str(python_decayed_lr), str(lr_val[0]))) + + if __name__ == '__main__': unittest.main() -- GitLab