# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from abc import abstractmethod
from typing import Union

from paddle.optimizer import lr
from ppcls.utils import logger


class LRBase(object):
    """Base class for custom learning rates

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
        verbose (bool): If True, prints a message to stdout for each update. Defaults to False
    """

    def __init__(self,
                 epochs: int,
                 step_each_epoch: int,
                 learning_rate: float,
                 warmup_epoch: int,
                 warmup_start_lr: float,
                 last_epoch: int,
                 by_epoch: bool,
                 verbose: bool=False) -> None:
        """Initialize and record the necessary parameters
        """
        super(LRBase, self).__init__()
        if warmup_epoch >= epochs:
            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than the value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            logger.warning(msg)
            warmup_epoch = epochs
        self.epochs = epochs
        self.step_each_epoch = step_each_epoch
        self.learning_rate = learning_rate
        self.warmup_epoch = warmup_epoch
        self.warmup_steps = round(
            self.warmup_epoch *
            self.step_each_epoch) if by_epoch else self.warmup_epoch
        self.warmup_start_lr = warmup_start_lr
        self.last_epoch = last_epoch
        self.by_epoch = by_epoch
        self.verbose = verbose

    @abstractmethod
    def __call__(self, *args, **kwargs) -> lr.LRScheduler:
        """Generate a learning rate scheduler.

        Returns:
            lr.LRScheduler: learning rate scheduler
        """
        pass

    def linear_warmup(
            self,
            learning_rate: Union[float, lr.LRScheduler]) -> lr.LinearWarmup:
        """Add a linear warmup stage before learning_rate

        Args:
            learning_rate (Union[float, lr.LRScheduler]): original learning rate without warmup

        Returns:
            lr.LinearWarmup: learning rate scheduler with warmup
        """
        warmup_lr = lr.LinearWarmup(
            learning_rate=learning_rate,
            warmup_steps=self.warmup_steps,
            start_lr=self.warmup_start_lr,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch,
            verbose=self.verbose)
        return warmup_lr
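    # Illustrative note (not part of the original file): warmup_steps is
    # measured in iterations when by_epoch=False and in epochs when
    # by_epoch=True. Assuming warmup_epoch=5 and step_each_epoch=500:
    #
    #   by_epoch=False -> warmup_steps = round(5 * 500) = 2500 iterations
    #   by_epoch=True  -> warmup_steps = 5 epochs
    #
    # linear_warmup() wraps the decayed schedule in lr.LinearWarmup, so the
    # learning rate ramps from warmup_start_lr up to learning_rate over those
    # steps before the underlying decay takes over.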


class Constant(LRBase):
    """Constant learning rate

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Constant, self).__init__(epochs, step_each_epoch, learning_rate,
                                       warmup_epoch, warmup_start_lr,
                                       last_epoch, by_epoch)

    def __call__(self):
        learning_rate = lr.LRScheduler(
            learning_rate=self.learning_rate, last_epoch=self.last_epoch)

        # Patch a constant get_lr onto the bare scheduler. The closure captures
        # the configured value, so the patched function needs no self binding
        # and always returns the same learning rate.
        def make_get_lr(constant_lr):
            def get_lr():
                return constant_lr

            return get_lr

        setattr(learning_rate, "get_lr", make_get_lr(self.learning_rate))

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
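    # Illustrative usage sketch (assumed config values, not from the original
    # file): the wrapper is constructed from config fields and then called to
    # obtain a paddle scheduler whose get_lr() always returns the same value.
    #
    #   sched = Constant(epochs=120, step_each_epoch=1000, learning_rate=0.01)()
    #   # sched.get_lr() -> 0.01 at every step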


class Linear(LRBase):
    """Linear learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        end_lr (float, optional): The minimum final learning rate. Defaults to 0.0.
        power (float, optional): Power of polynomial. Defaults to 1.0.
        cycle (bool, optional): Whether the polynomial decay restarts (rises again) after the learning rate has decayed to end_lr. Defaults to False.
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 end_lr=0.0,
                 power=1.0,
                 cycle=False,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Linear, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.decay_steps = (epochs - self.warmup_epoch) * step_each_epoch
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        self.warmup_steps = round(self.warmup_epoch * step_each_epoch)
        if self.by_epoch:
            self.decay_steps = self.epochs - self.warmup_epoch

    def __call__(self):
        learning_rate = lr.PolynomialDecay(
            learning_rate=self.learning_rate,
            decay_steps=self.decay_steps,
            end_lr=self.end_lr,
            power=self.power,
            cycle=self.cycle,
            last_epoch=self.
            last_epoch) if self.decay_steps > 0 else self.learning_rate

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
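    # Illustrative sketch (assumed numbers, not from the original file): with
    # epochs=100, step_each_epoch=500 and warmup_epoch=5, decay_steps is
    # (100 - 5) * 500 = 47500 iterations, and power=1.0 makes PolynomialDecay
    # reduce the learning rate linearly from learning_rate down to end_lr.
    #
    #   sched = Linear(epochs=100, step_each_epoch=500, learning_rate=0.1,
    #                  end_lr=0.0, warmup_epoch=5)()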


class Cosine(LRBase):
    """Cosine learning rate decay

    ``lr = 0.5 * learning_rate * (cos(epoch * pi / epochs) + 1)``

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 eta_min=0.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Cosine, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.T_max = (self.epochs - self.warmup_epoch) * self.step_each_epoch
        self.eta_min = eta_min
        if self.by_epoch:
            self.T_max = self.epochs - self.warmup_epoch

    def __call__(self):
        learning_rate = lr.CosineAnnealingDecay(
            learning_rate=self.learning_rate,
            T_max=self.T_max,
            eta_min=self.eta_min,
            last_epoch=self.
            last_epoch) if self.T_max > 0 else self.learning_rate

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
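    # Illustrative training-loop sketch (assumptions, not from the original
    # file): the returned scheduler carries a by_epoch flag that tells the
    # caller whether to step once per iteration or once per epoch.
    #
    #   sched = Cosine(epochs=100, step_each_epoch=500, learning_rate=0.1)()
    #   for epoch in range(100):
    #       for it in range(500):
    #           ...  # forward / backward / optimizer.step()
    #           if not sched.by_epoch:
    #               sched.step()
    #       if sched.by_epoch:
    #           sched.step()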


class Step(LRBase):
    """Step learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        step_size (int): the interval, in epochs, at which the learning rate is decayed.
        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 step_size,
                 gamma,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
                                   warmup_epoch, warmup_start_lr, last_epoch,
                                   by_epoch)
        self.step_size = step_size * step_each_epoch
        self.gamma = gamma
        if self.by_epoch:
            self.step_size = step_size

    def __call__(self):
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
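    # Illustrative sketch (assumed numbers): step_size is given in epochs and
    # converted to iterations when by_epoch=False, so step_size=30 with
    # step_each_epoch=500 multiplies the learning rate by gamma every
    # 30 * 500 = 15000 iterations.
    #
    #   sched = Step(epochs=90, step_each_epoch=500, learning_rate=0.1,
    #                step_size=30, gamma=0.1)()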


class Piecewise(LRBase):
    """Piecewise learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        decay_epochs (List[int]): A list of epoch boundaries at which the learning rate is decayed.
        values (List[float]): A list of learning rate values picked between the epoch boundaries; it must contain one more element than decay_epochs.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Piecewise,
              self).__init__(epochs, step_each_epoch, values[0], warmup_epoch,
                             warmup_start_lr, last_epoch, by_epoch)
        self.values = values
        self.boundaries_steps = [e * step_each_epoch for e in decay_epochs]
        if self.by_epoch is True:
            self.boundaries_steps = decay_epochs

    def __call__(self):
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries_steps,
            values=self.values,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
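    # Illustrative sketch (assumed numbers): decay_epochs=[30, 60] with
    # values=[0.1, 0.01, 0.001] keeps the learning rate at 0.1 until epoch 30,
    # 0.01 until epoch 60 and 0.001 afterwards; with by_epoch=False the
    # boundaries are converted to iterations (e.g. 30 * 500 = 15000).
    #
    #   sched = Piecewise(epochs=90, step_each_epoch=500,
    #                     decay_epochs=[30, 60], values=[0.1, 0.01, 0.001])()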


class MultiStepDecay(LRBase):
    """MultiStepDecay learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        milestones (List[int]): List of epoch boundaries at which the learning rate is decayed. Must be increasing.
        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Defaults to 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(MultiStepDecay, self).__init__(
            epochs, step_each_epoch, learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch)
        self.milestones = [x * step_each_epoch for x in milestones]
        self.gamma = gamma
        if self.by_epoch:
            self.milestones = milestones

    def __call__(self):
        learning_rate = lr.MultiStepDecay(
            learning_rate=self.learning_rate,
            milestones=self.milestones,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
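    # Illustrative sketch (assumed numbers): milestones are epochs, scaled to
    # iterations when by_epoch=False, and the learning rate is multiplied by
    # gamma at each milestone (0.1 -> 0.01 at epoch 30, -> 0.001 at epoch 60).
    #
    #   sched = MultiStepDecay(epochs=90, step_each_epoch=500,
    #                          learning_rate=0.1, milestones=[30, 60])()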