learning_rate.py 15.0 KB
Newer Older
S
shippingwang 已提交
1
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
W
WuHaobo 已提交
2
#
S
shippingwang 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
S
shippingwang 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

D
dongshuilong 已提交
15 16 17
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from bisect import bisect_right

from paddle.optimizer import lr
from paddle.optimizer.lr import LRScheduler

from ppcls.utils import logger

W
WuHaobo 已提交
23

L
littletomatodonkey 已提交
24
class Linear(object):
    """
    Polynomial (linear, by default) learning rate decay with an optional
    linear warmup phase.

    Args:
        learning_rate (float): Initial learning rate.
        epochs (int): Total number of training epochs.
        step_each_epoch (int): Number of steps (iterations) per epoch.
        end_lr (float, optional): Final learning rate after decay. Default: 0.0.
        power (float, optional): Power of the polynomial. Default: 1.0.
        warmup_epoch (int): Number of epochs used for linear warmup. Default: 0.
        warmup_start_lr (float): Learning rate at the start of warmup. Default: 0.0.
        last_epoch (int, optional): Index of the last epoch, used to resume
            training. Default: -1 (start from the initial learning rate).
    """

    def __init__(self,
                 learning_rate,
                 epochs,
                 step_each_epoch,
                 end_lr=0.0,
                 power=1.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            logger.warning(
                f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            )
            warmup_epoch = epochs
        self.learning_rate = learning_rate
        self.end_lr = end_lr
        self.power = power
        self.last_epoch = last_epoch
        # The decay phase covers only the steps left over after warmup.
        self.steps = (epochs - warmup_epoch) * step_each_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        if self.steps > 0:
            learning_rate = lr.PolynomialDecay(
                learning_rate=self.learning_rate,
                decay_steps=self.steps,
                end_lr=self.end_lr,
                power=self.power,
                last_epoch=self.last_epoch)
        else:
            # Warmup consumed every epoch: no decay phase, keep a plain float.
            learning_rate = self.learning_rate
        if self.warmup_steps > 0:
            # Wrap the decay schedule (or constant) with linear warmup.
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_steps,
                start_lr=self.warmup_start_lr,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
class Constant(LRScheduler):
    """
    Constant learning rate: ``get_lr`` always returns the configured value.

    Args:
        learning_rate (float): The learning rate to hold constant. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
    """

    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
        # Attributes are assigned before calling the base initializer,
        # presumably because LRScheduler.__init__ invokes get_lr(), which
        # reads self.learning_rate -- TODO(review): confirm against paddle's
        # LRScheduler implementation.
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch
        # NOTE(review): super().__init__() is called with no arguments, so the
        # base class sees its own defaults rather than the values stored above.
        super().__init__()

    def get_lr(self):
        # Never decays: the schedule is flat for the whole run.
        return self.learning_rate


L
littletomatodonkey 已提交
95
class Cosine(object):
    """
    Cosine annealing learning rate decay with an optional linear warmup phase.

    lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)

    Args:
        learning_rate (float): Initial learning rate.
        step_each_epoch (int): Number of steps (iterations) per epoch.
        epochs (int): Total number of training epochs.
        eta_min (float): Minimum learning rate. Default: 0.0.
        warmup_epoch (int): Number of epochs used for linear warmup. Default: 0.
        warmup_start_lr (float): Learning rate at the start of warmup. Default: 0.0.
        last_epoch (int, optional): Index of the last epoch, used to resume
            training. Default: -1 (start from the initial learning rate).
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 eta_min=0.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            logger.warning(
                f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            )
            warmup_epoch = epochs
        self.learning_rate = learning_rate
        self.eta_min = eta_min
        self.last_epoch = last_epoch
        # The annealing period covers only the non-warmup part of training.
        self.T_max = (epochs - warmup_epoch) * step_each_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        if self.T_max > 0:
            learning_rate = lr.CosineAnnealingDecay(
                learning_rate=self.learning_rate,
                T_max=self.T_max,
                eta_min=self.eta_min,
                last_epoch=self.last_epoch)
        else:
            # Warmup consumed every epoch: no annealing phase, keep a float.
            learning_rate = self.learning_rate
        if self.warmup_steps > 0:
            # Wrap the annealing schedule (or constant) with linear warmup.
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=self.warmup_steps,
                start_lr=self.warmup_start_lr,
                end_lr=self.learning_rate,
                last_epoch=self.last_epoch)
        return learning_rate


class Step(object):
    """
    Step learning rate decay: the rate is multiplied by ``gamma`` every
    ``step_size`` epochs, with an optional linear warmup phase.

    Args:
        learning_rate (float): Initial learning rate.
        step_size (int): Decay interval, in epochs.
        step_each_epoch (int): Number of steps (iterations) per epoch.
        epochs (int): Total number of training epochs.
        gamma (float): Ratio by which the learning rate is reduced:
            ``new_lr = origin_lr * gamma``. It should be less than 1.0.
        warmup_epoch (int): Number of epochs used for linear warmup. Default: 0.
        warmup_start_lr (float): Learning rate at the start of warmup. Default: 0.0.
        last_epoch (int, optional): Index of the last epoch, used to resume
            training. Default: -1 (start from the initial learning rate).
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 step_each_epoch,
                 epochs,
                 gamma,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            logger.warning(
                f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            )
            warmup_epoch = epochs
        self.learning_rate = learning_rate
        self.gamma = gamma
        # Convert the epoch-based decay interval into iteration steps.
        self.step_size = step_each_epoch * step_size
        self.last_epoch = last_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_start_lr = warmup_start_lr

    def __call__(self):
        decay = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)
        if self.warmup_steps <= 0:
            return decay
        # Wrap the step-decay schedule with a linear warmup phase.
        return lr.LinearWarmup(
            learning_rate=decay,
            warmup_steps=self.warmup_steps,
            start_lr=self.warmup_start_lr,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)


class Piecewise(object):
    """
    Piecewise-constant learning rate decay with an optional linear warmup
    phase.

    Args:
        step_each_epoch (int): Number of steps (iterations) per epoch.
        decay_epochs (list): Epoch numbers at which the learning rate changes.
        values (list): Learning rate values picked between successive
            boundaries. The type of element in the list is python float.
        epochs (int): Total number of training epochs.
        warmup_epoch (int): Number of epochs used for linear warmup. Default: 0.
        warmup_start_lr (float): Learning rate at the start of warmup. Default: 0.0.
        by_epoch (bool): Whether lr decay by epoch. Default: False.
        last_epoch (int, optional): Index of the last epoch, used to resume
            training. Default: -1 (start from the initial learning rate).
    """

    def __init__(self,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 epochs,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 by_epoch=False,
                 last_epoch=-1,
                 **kwargs):
        super().__init__()
        if warmup_epoch >= epochs:
            logger.warning(
                f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            )
            warmup_epoch = epochs
        # Keep boundaries in both units; __call__ picks one based on by_epoch.
        self.boundaries_steps = [step_each_epoch * e for e in decay_epochs]
        self.boundaries_epoch = decay_epochs
        self.values = values
        self.last_epoch = last_epoch
        self.warmup_steps = round(warmup_epoch * step_each_epoch)
        self.warmup_epoch = warmup_epoch
        self.warmup_start_lr = warmup_start_lr
        self.by_epoch = by_epoch

    def __call__(self):
        # Select epoch- or step-based boundaries and the matching warmup span.
        if self.by_epoch:
            boundaries, warmup = self.boundaries_epoch, self.warmup_epoch
        else:
            boundaries, warmup = self.boundaries_steps, self.warmup_steps
        learning_rate = lr.PiecewiseDecay(
            boundaries=boundaries,
            values=self.values,
            last_epoch=self.last_epoch)
        if warmup > 0:
            learning_rate = lr.LinearWarmup(
                learning_rate=learning_rate,
                warmup_steps=warmup,
                start_lr=self.warmup_start_lr,
                end_lr=self.values[0],
                last_epoch=self.last_epoch)
        # Downstream code reads this flag to decide the stepping cadence.
        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
D
dongshuilong 已提交
263 264 265 266 267


class MultiStepDecay(LRScheduler):
    """
    Piecewise-constant decay: the learning rate is multiplied by ``gamma``
    each time training passes one of the ``milestones``.

    Unlike paddle's built-in ``MultiStepDecay``, milestones here are given in
    epochs and converted to step (iteration) boundaries using
    ``step_each_epoch``, so ``scheduler.step()`` is expected to be called once
    per iteration.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]   # epochs
        gamma = 0.1
        step < 30 * step_each_epoch             -> lr = 0.5
        step < 50 * step_each_epoch             -> lr = 0.05
        otherwise                               -> lr = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of epoch boundaries. Must be increasing.
        epochs (int): Total training epochs. Unused here; accepted so the
            signature matches the other schedulers in this module.
        step_each_epoch (int): Number of iterations per epoch, used to convert
            ``milestones`` from epochs to steps.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``milestones`` is not a tuple or list.
        ValueError: If ``milestones`` is not increasing or ``gamma`` >= 1.0.

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 epochs,
                 step_each_epoch,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))
        # Boundaries must be strictly increasing for the decay exponent in
        # get_lr() to be well defined.
        if not all(a < b for a, b in zip(milestones, milestones[1:])):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')
        # Convert epoch milestones into iteration-step boundaries.
        self.milestones = [x * step_each_epoch for x in milestones]
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The decay exponent is the number of milestones already passed;
        # bisect finds it in O(log n) instead of a linear scan.
        return self.base_lr * self.gamma**bisect_right(self.milestones,
                                                       self.last_epoch)