learning_rate_scheduler.py 8.5 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it is often useful to decay the learning rate
as training progresses; this is called learning rate decay. There are
many strategies for doing this, and this module provides some
classical ones. Users can also implement their own learning rate
decay strategy by following the functions in this module.
"""
Q
Qiao Longfei 已提交
22

23 24 25 26 27
import control_flow
import nn
import ops
import tensor
from ..initializer import init_on_cpu
Q
Qiao Longfei 已提交
28

29 30
# Public learning rate decay strategies exported by this module.
__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
]
Q
Qiao Longfei 已提交
33 34


35
def _decay_step_counter(begin=0):
    """Build the step counter used by the learning rate decay functions.

    Args:
        begin: Forwarded as ``begin`` to ``nn.autoincreased_step_counter``.
            With the default of 0, the first global step is zero in
            learning rate decay.

    Returns:
        The ``@LR_DECAY_COUNTER@`` counter, cast to ``float32``.
    """
    step_counter = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1)
    # The counter is cast to float32 before being handed to the decay
    # formulas, which combine it with floating point values.
    return tensor.cast(step_counter, 'float32')


43
def noam_decay(d_model, warmup_steps):
    """
    Noam decay method. The numpy implementation of noam decay is as follows.

    >>> import numpy as np
    >>> lr_value = np.power(d_model, -0.5) * np.min([
    ...     np.power(current_steps, -0.5),
    ...     np.power(warmup_steps, -1.5) * current_steps])

    Please refer to `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(Variable): The dimensionality of input and output of model.
        warmup_steps(Variable): A hyperparameter; it appears in the
            formula above as ``warmup_steps ** -1.5``.

    Returns:
        The decayed learning rate.
    """
    # Unlike the other schedules in this module, the counter begins at 1
    # here (a step of 0 would make ``step ** -0.5`` undefined).
    step = _decay_step_counter(1)
    with init_on_cpu():
        inv_sqrt_step = step**-0.5
        warmup_term = (warmup_steps**-1.5) * step
        model_scale = d_model**-0.5
        lr_value = model_scale * ops.elementwise_min(inv_sqrt_step,
                                                     warmup_term)

    return lr_value


Y
Yu Yang 已提交
72
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """Applies exponential decay to the learning rate.

    .. code-block:: python

        if staircase:
            decayed_learning_rate = learning_rate * \
                decay_rate ^ floor(global_step / decay_steps)
        else:
            decayed_learning_rate = learning_rate * \
                decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate: A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set true, the exponent is floored, so the
          learning rate only changes once every decay_steps steps.

    Returns:
        The decayed learning rate
    """
    step = _decay_step_counter()

    with init_on_cpu():
        # Exponent of decay_rate: elapsed steps measured in decay_steps.
        exponent = step / decay_steps
        if staircase:
            exponent = ops.floor(exponent)
        decay_factor = decay_rate**exponent
        decayed_lr = learning_rate * decay_factor

    return decayed_lr
Q
Qiao Longfei 已提交
99 100


Y
Yu Yang 已提交
101
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """Applies natural exponential decay to the initial learning rate.

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate: A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set true, the step ratio is floored, so the
          learning rate only changes once every decay_steps steps.

    Returns:
        The decayed learning rate
    """
    step = _decay_step_counter()

    with init_on_cpu():
        step_ratio = step / decay_steps
        if staircase:
            step_ratio = ops.floor(step_ratio)
        # Negate the rate first, then scale by the step ratio.
        negative_rate = -1 * decay_rate
        decayed_lr = learning_rate * ops.exp(negative_rate * step_ratio)

    return decayed_lr
Q
Qiao Longfei 已提交
128 129


Y
Yu Yang 已提交
130
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """Applies inverse time decay to the initial learning rate.

    >>> if staircase:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate: A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set true, the step ratio is floored, so the
          learning rate only changes once every decay_steps steps.

    Returns:
        The decayed learning rate
    """
    step = _decay_step_counter()

    with init_on_cpu():
        step_ratio = step / decay_steps
        if staircase:
            step_ratio = ops.floor(step_ratio)

        denominator = 1 + decay_rate * step_ratio
        decayed_lr = learning_rate / denominator

    return decayed_lr
158 159 160 161 162 163 164


def polynomial_decay(learning_rate,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False):
    """
    **polynomial_decay**

    Applies polynomial decay to the initial learning rate.

    .. code-block:: python

        if cycle:
            # ceil(...) is replaced by 1 when global_step is 0
            decay_steps = decay_steps * ceil(global_step / decay_steps)
        else:
            global_step = min(global_step, decay_steps)
        decayed_learning_rate = (learning_rate - end_learning_rate) * \
            (1 - global_step / decay_steps) ^ power + end_learning_rate

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool, Default False): Boolean. If set true, decay_steps is
          scaled up to the next multiple of itself that covers global_step;
          otherwise global_step is clamped to decay_steps.

    Returns:
        The decayed learning rate
    """
    step = _decay_step_counter()

    with init_on_cpu():
        if cycle:
            # Number of decay periods needed to cover the current step.
            periods = ops.ceil(step / decay_steps)
            zero_const = tensor.fill_constant(
                shape=[1], dtype='float32', value=0.0)
            one_const = tensor.fill_constant(
                shape=[1], dtype='float32', value=1.0)

            # At step 0 the ceil above is 0; overwrite it with 1 so the
            # effective number of decay steps is never zero.
            with control_flow.Switch() as switch:
                with switch.case(step == zero_const):
                    tensor.assign(input=one_const, output=periods)
            total_steps = decay_steps * periods
        else:
            # Clamp the step so the learning rate stays at its end value
            # once decay_steps has been reached.
            step_limit = tensor.fill_constant(
                shape=[1], dtype='float32', value=float(decay_steps))
            step = ops.elementwise_min(x=step, y=step_limit)
            total_steps = decay_steps

        lr_span = learning_rate - end_learning_rate
        remaining = (1 - step / total_steps) ** power
        decayed_lr = lr_span * remaining + end_learning_rate
    return decayed_lr
212 213


Y
Yu Yang 已提交
214
def piecewise_decay(boundaries, values):
    """Applies piecewise decay to the initial learning rate.

    >>> boundaries = [10000, 20000]
    >>> values = [1.0, 0.5, 0.1]
    >>>
    >>> if step < 10000:
    >>>     learning_rate = 1.0
    >>> elif 10000 <= step < 20000:
    >>>     learning_rate = 0.5
    >>> else:
    >>>     learning_rate = 0.1

    Args:
        boundaries: A sequence of step numbers at which the learning rate
          switches to the next entry of values.
        values: A sequence of learning rates; it must contain exactly one
          more element than boundaries.

    Returns:
        The global variable named "learning_rate" holding the value
        selected for the current step.

    Raises:
        ValueError: If len(values) - len(boundaries) is not 1.
    """

    if len(values) - len(boundaries) != 1:
        raise ValueError("len(values) - len(boundaries) should be 1")

    step = _decay_step_counter()

    with init_on_cpu():
        lr = tensor.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate")

        with control_flow.Switch() as switch:
            # One case per boundary: while the step is below the boundary,
            # the paired value is assigned. zip stops at the shorter
            # sequence, so the extra last value is left for the default.
            for boundary, value in zip(boundaries, values):
                boundary_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(boundary))
                piece_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(value))
                with switch.case(step < boundary_var):
                    tensor.assign(piece_var, lr)

            # Steps at or beyond the last boundary use the final value.
            final_var = tensor.fill_constant(
                shape=[1], dtype='float32', value=float(values[-1]))
            with switch.default():
                tensor.assign(final_var, lr)

    return lr