# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from paddle.optimizer import lr
from paddle.optimizer.lr import LRScheduler

from ppcls.utils import logger


class Linear(object):
    """
    Linear learning rate decay
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        epochs (int): total training epochs, used to compute the number of decay steps.
        step_each_epoch (int): steps each epoch.
        end_lr (float, optional): The minimum final learning rate. Default: 0.0.
        power (float, optional): Power of polynomial. Default: 1.0.
        warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
        warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
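
    Examples:
        A minimal usage sketch; the hyper-parameter values below are purely
        illustrative assumptions, not recommended settings.

        .. code-block:: python

            # linearly decay from 0.1 to 0.0 over 10 epochs of 100 steps each,
            # with a 1-epoch linear warmup starting from 0.0
            lr_sch = Linear(learning_rate=0.1, epochs=10, step_each_epoch=100,
                            end_lr=0.0, warmup_epoch=1)()
            # lr_sch is a paddle.optimizer.lr scheduler; it can be passed to a
            # paddle optimizer as its learning_rate and stepped each iteration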
    """

    def __init__(self,
                 learning_rate,
                 epochs,
                 step_each_epoch,
                 end_lr=0.0,
                 power=1.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than the value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            logger.warning(msg)
            warmup_epoch = epochs
        self.learning_rate = learning_rate
        self.steps = (epochs - warmup_epoch) * step_each_epoch
        self.end_lr = end_lr
        self.power = power
        self.last_epoch = last_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        learning_rate = lr.PolynomialDecay(
            learning_rate=self.learning_rate,
            decay_steps=self.steps,
            end_lr=self.end_lr,
            power=self.power,
            last_epoch=self.last_epoch
        ) if self.steps > 0 else self.learning_rate
        if self.warmup_steps > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_steps,
                start_lr=self.warmup_start_lr,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Cosine(object):
    """
    Cosine learning rate decay
    lr = 0.5 * learning_rate * (math.cos(epoch * (math.pi / epochs)) + 1)
    Args:
        learning_rate (float): initial learning rate
        step_each_epoch(int): steps each epoch
        epochs(int): total training epochs
        eta_min(float): Minimum learning rate. Default: 0.0.
        warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
        warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
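
    Examples:
        A minimal usage sketch; the values below are illustrative assumptions,
        not recommended settings.

        .. code-block:: python

            # cosine decay from 0.1 towards eta_min over 10 epochs of 100 steps
            # each, preceded by a 1-epoch linear warmup
            lr_sch = Cosine(learning_rate=0.1, step_each_epoch=100, epochs=10,
                            eta_min=0.0, warmup_epoch=1)()
            # the returned object is a paddle.optimizer.lr scheduler; call
            # lr_sch.step() once per training iteration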
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 eta_min=0.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than the value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            logger.warning(msg)
            warmup_epoch = epochs
        self.learning_rate = learning_rate
        self.T_max = (epochs - warmup_epoch) * step_each_epoch
        self.eta_min = eta_min
        self.last_epoch = last_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        learning_rate = lr.CosineAnnealingDecay(
            learning_rate=self.learning_rate,
            T_max=self.T_max,
            eta_min=self.eta_min,
            last_epoch=self.last_epoch
        ) if self.T_max > 0 else self.learning_rate
        if self.warmup_steps > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_steps,
                start_lr=self.warmup_start_lr,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Step(object):
    """
    Step learning rate decay
    Args:
        step_each_epoch(int): steps each epoch
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval, in epochs, at which the learning rate is updated.
        epochs (int): total training epochs.
        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
        warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
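
    Examples:
        A minimal usage sketch; the values below are illustrative assumptions,
        not recommended settings.

        .. code-block:: python

            # multiply the learning rate by gamma=0.1 every 30 epochs
            # (90 epochs total, 100 steps per epoch, no warmup)
            lr_sch = Step(learning_rate=0.1, step_size=30, step_each_epoch=100,
                          epochs=90, gamma=0.1)()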
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 step_each_epoch,
                 epochs,
                 gamma,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than the value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            logger.warning(msg)
            warmup_epoch = epochs
        self.step_size = step_each_epoch * step_size
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.last_epoch = last_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)
        if self.warmup_steps > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_steps,
                start_lr=self.warmup_start_lr,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Piecewise(object):
    """
    Piecewise learning rate decay
    Args:
        step_each_epoch (int): steps each epoch.
        decay_epochs (list): A list of epoch numbers at which the learning rate is decayed. The type of element in the list is python int.
        values (list): A list of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float.
        epochs (int): total training epochs.
        warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
        warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
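
    Examples:
        A minimal usage sketch; the values below are illustrative assumptions,
        not recommended settings.

        .. code-block:: python

            # use 0.1 until epoch 30, 0.01 until epoch 60, then 0.001
            # (90 epochs total, 100 steps per epoch, no warmup)
            lr_sch = Piecewise(step_each_epoch=100, decay_epochs=[30, 60],
                               values=[0.1, 0.01, 0.001], epochs=90)()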
    """

    def __init__(self,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 epochs,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than the value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            logger.warning(msg)
            warmup_epoch = epochs
        self.boundaries = [step_each_epoch * e for e in decay_epochs]
        self.values = values
        self.last_epoch = last_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries,
            values=self.values,
            last_epoch=self.last_epoch)
        if self.warmup_steps > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_steps,
                start_lr=self.warmup_start_lr,
                end_lr=self.values[0],
                last_epoch=self.last_epoch)
        return learning_rate


class MultiStepDecay(LRScheduler):
    """
    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.
    The algorithm can be described as the code below.
    .. code-block:: text
        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005
    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of boundary epochs. Must be increasing.
        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.
    Examples:

        .. code-block:: python
            import paddle
            import numpy as np
            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)
            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
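
        The example above documents the upstream ``paddle.optimizer.lr.MultiStepDecay``
        scheduler; this wrapper additionally takes ``epochs`` and ``step_each_epoch``
        and interprets ``milestones`` in epochs. A minimal sketch of constructing it
        directly, with purely illustrative values:

        .. code-block:: python

            # milestones are given in epochs and converted to iteration counts
            # internally (here: [30, 60] epochs -> [3000, 6000] iterations,
            # assuming 100 iterations per epoch)
            scheduler = MultiStepDecay(learning_rate=0.1, milestones=[30, 60],
                                       epochs=90, step_each_epoch=100, gamma=0.1)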
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 epochs,
                 step_each_epoch,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))
        if not all([
                milestones[i] < milestones[i + 1]
                for i in range(len(milestones) - 1)
        ]):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')
        self.milestones = [x * step_each_epoch for x in milestones]
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        for i in range(len(self.milestones)):
            if self.last_epoch < self.milestones[i]:
                return self.base_lr * (self.gamma**i)
        return self.base_lr * (self.gamma**len(self.milestones))