未验证 提交 35483a20 编写于 作者: G gongweibao 提交者: GitHub

Add neural transformer leanring rate decay function. (#9951)

Add neural transformer leanring rate decay function
上级 fbe56247
...@@ -20,7 +20,7 @@ from ..initializer import init_on_cpu ...@@ -20,7 +20,7 @@ from ..initializer import init_on_cpu
__all__ = [ __all__ = [
'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
'polynomial_decay', 'piecewise_decay' 'polynomial_decay', 'piecewise_decay', 'noam_decay'
] ]
""" """
When training a model, it's often useful to decay the When training a model, it's often useful to decay the
...@@ -32,14 +32,41 @@ strategy according to this module. ...@@ -32,14 +32,41 @@ strategy according to this module.
""" """
def _decay_step_counter(): def _decay_step_counter(begin=0):
# the first global step is zero in learning rate decay # the first global step is zero in learning rate decay
global_step = nn.autoincreased_step_counter( global_step = nn.autoincreased_step_counter(
counter_name='@LR_DECAY_COUNTER@', begin=0, step=1) counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1)
global_step = tensor.cast(global_step, 'float32') global_step = tensor.cast(global_step, 'float32')
return global_step return global_step
def noam_decay(d_model, warmup_steps):
"""Apply decay to learning rate.
```python
lr_value = np.power(d_model, -0.5) * np.min([
np.power(current_steps, -0.5),
np.power(warmup_steps, -1.5) * current_steps
])
```
Args:
d_model(Variable): The dimensionality of input and output of model.
Reference: attention is all you need
https://arxiv.org/pdf/1706.03762.pdf
warmup_steps(Variable): A super parameter.
Returns:
The decayed learning rate.
"""
global_step = _decay_step_counter(1)
with init_on_cpu():
a = global_step**-0.5
b = (warmup_steps**-1.5) * global_step
lr_value = (d_model**-0.5) * ops.elementwise_min(a, b)
return lr_value
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""Applies exponential decay to the learning rate. """Applies exponential decay to the learning rate.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册