From 35483a209480094c3ca8c72285a1249ada5db6c4 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 18 Apr 2018 13:20:35 +0800 Subject: [PATCH] Add neural transformer leanring rate decay function. (#9951) Add neural transformer leanring rate decay function --- .../fluid/layers/learning_rate_scheduler.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index 65b95a58d6..d13c54daa5 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -20,7 +20,7 @@ from ..initializer import init_on_cpu __all__ = [ 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', - 'polynomial_decay', 'piecewise_decay' + 'polynomial_decay', 'piecewise_decay', 'noam_decay' ] """ When training a model, it's often useful to decay the @@ -32,14 +32,41 @@ strategy according to this module. """ -def _decay_step_counter(): +def _decay_step_counter(begin=0): # the first global step is zero in learning rate decay global_step = nn.autoincreased_step_counter( - counter_name='@LR_DECAY_COUNTER@', begin=0, step=1) + counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1) global_step = tensor.cast(global_step, 'float32') return global_step +def noam_decay(d_model, warmup_steps): + """Apply decay to learning rate. + ```python + lr_value = np.power(d_model, -0.5) * np.min([ + np.power(current_steps, -0.5), + np.power(warmup_steps, -1.5) * current_steps + ]) + ``` + + Args: + d_model(Variable): The dimensionality of input and output of model. + Reference: attention is all you need + https://arxiv.org/pdf/1706.03762.pdf + warmup_steps(Variable): A super parameter. + + Returns: + The decayed learning rate. + """ + global_step = _decay_step_counter(1) + with init_on_cpu(): + a = global_step**-0.5 + b = (warmup_steps**-1.5) * global_step + lr_value = (d_model**-0.5) * ops.elementwise_min(a, b) + + return lr_value + + def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): """Applies exponential decay to the learning rate. -- GitLab