# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

D
dongshuilong 已提交
15 16 17
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

H
HydrogenSulfate 已提交
18 19
from abc import abstractmethod
from typing import Union
W
WuHaobo 已提交
20

H
HydrogenSulfate 已提交
21
from paddle.optimizer import lr
22 23
from ppcls.utils import logger

W
WuHaobo 已提交
24

H
HydrogenSulfate 已提交
25 26 27 28 29 30 31 32 33 34 35 36
class LRBase(object):
    """Shared foundation for the learning rate builders in this module

    Keeps the schedule settings common to every builder and provides a helper
    that puts a linear warmup phase in front of another learning rate.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
        verbose (bool): passed on to lr.LinearWarmup; if True, prints a message to stdout for each update. Defaults to False
    """

    def __init__(self,
                 epochs: int,
                 step_each_epoch: int,
                 learning_rate: float,
                 warmup_epoch: int,
                 warmup_start_lr: float,
                 last_epoch: int,
                 by_epoch: bool,
                 verbose: bool=False) -> None:
        """Clamp the warmup length and store all settings on the instance
        """
        super().__init__()

        # warmup may not last longer than the whole run: clamp and warn
        if warmup_epoch >= epochs:
            logger.warning(
                "When using warm up, the value of \"Global.epochs\" must be "
                "greater than value of \"Optimizer.lr.warmup_epoch\". "
                "The value of \"Optimizer.lr.warmup_epoch\" "
                f"has been set to {epochs}.")
            warmup_epoch = epochs

        # warmup length in the unit the scheduler is stepped with
        if by_epoch:
            warmup_steps = warmup_epoch
        else:
            warmup_steps = round(warmup_epoch * step_each_epoch)

        self.epochs = epochs
        self.step_each_epoch = step_each_epoch
        self.learning_rate = learning_rate
        self.warmup_epoch = warmup_epoch
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.last_epoch = last_epoch
        self.by_epoch = by_epoch
        self.verbose = verbose

    @abstractmethod
    def __call__(self, *kargs, **kwargs) -> lr.LRScheduler:
        """Build the learning rate scheduler; implemented by subclasses

        Returns:
            lr.LRScheduler: learning rate scheduler
        """
        pass

    def linear_warmup(
            self,
            learning_rate: Union[float, lr.LRScheduler]) -> lr.LinearWarmup:
        """Wrap ``learning_rate`` in a linear warmup phase

        The warmup runs for ``self.warmup_steps`` steps, going from
        ``self.warmup_start_lr`` to ``self.learning_rate``.

        Args:
            learning_rate (Union[float, lr.LRScheduler]): original learning rate without warmup

        Returns:
            lr.LinearWarmup: learning rate scheduler with warmup
        """
        return lr.LinearWarmup(
            learning_rate=learning_rate,
            warmup_steps=self.warmup_steps,
            start_lr=self.warmup_start_lr,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch,
            verbose=self.verbose)


H
HydrogenSulfate 已提交
96
class Constant(lr.LRScheduler):
    """Constant learning rate Class implementation

    Args:
        learning_rate (float): The learning rate returned at every step
        last_epoch (int, optional): The index of last epoch. Default: -1.
    """

    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
        # must be set before the parent initializer runs, because get_lr()
        # reads it
        self.learning_rate = learning_rate
        # Forward both values to the parent initializer. Calling it without
        # arguments made it fall back to its own defaults, so the last_epoch
        # given by the caller was discarded.
        super(Constant, self).__init__(
            learning_rate=learning_rate, last_epoch=last_epoch)

    def get_lr(self) -> float:
        """always return the same learning rate
        """
        return self.learning_rate


H
HydrogenSulfate 已提交
115
class ConstLR(LRBase):
    """Builder for a fixed learning rate with an optional linear warmup

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s). Defaults to 0.
        warmup_start_lr (float): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int): last epoch. Defaults to -1.
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        # extra keyword arguments are accepted and ignored
        super().__init__(
            epochs=epochs,
            step_each_epoch=step_each_epoch,
            learning_rate=learning_rate,
            warmup_epoch=warmup_epoch,
            warmup_start_lr=warmup_start_lr,
            last_epoch=last_epoch,
            by_epoch=by_epoch)

    def __call__(self):
        """Create the constant scheduler, wrapped in warmup when configured
        """
        scheduler = Constant(
            learning_rate=self.learning_rate, last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            scheduler = self.linear_warmup(scheduler)

        # tag the scheduler with the stepping unit it was built for
        scheduler.by_epoch = self.by_epoch
        return scheduler


class Linear(LRBase):
    """Linear learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        end_lr (float, optional): The minimum final learning rate. Defaults to 0.0.
        power (float, optional): Power of polynomial. Defaults to 1.0.
        cycle (bool, optional): forwarded to lr.PolynomialDecay as its ``cycle`` argument. Defaults to False.
        warmup_epoch (int): number of warmup epoch(s). Defaults to 0.
        warmup_start_lr (float): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int): last epoch. Defaults to -1.
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 end_lr=0.0,
                 power=1.0,
                 cycle=False,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Linear, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        # self.warmup_steps is deliberately left as computed by LRBase, which
        # already expresses it in epochs or iterations according to by_epoch.
        # Recomputing it here in iterations made the warmup far too long when
        # by_epoch is True.
        if self.by_epoch:
            self.decay_steps = self.epochs - self.warmup_epoch
        else:
            self.decay_steps = (
                self.epochs - self.warmup_epoch) * self.step_each_epoch

    def __call__(self):
        """Create the polynomial decay scheduler, wrapped in warmup when configured
        """
        if self.decay_steps > 0:
            learning_rate = lr.PolynomialDecay(
                learning_rate=self.learning_rate,
                decay_steps=self.decay_steps,
                end_lr=self.end_lr,
                power=self.power,
                cycle=self.cycle,
                last_epoch=self.last_epoch)
        else:
            # nothing left to decay over (warmup covers the whole run):
            # fall back to the plain learning rate value
            learning_rate = self.learning_rate

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
205 206


H
HydrogenSulfate 已提交
207 208
class Cosine(LRBase):
    """Builder for a cosine annealing learning rate with optional linear warmup

    The annealing period ``T_max`` is the part of the run that remains after
    warmup, counted in epochs when ``by_epoch`` is True and in iterations
    otherwise.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 eta_min=0.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        # extra keyword arguments are accepted and ignored
        super().__init__(
            epochs=epochs,
            step_each_epoch=step_each_epoch,
            learning_rate=learning_rate,
            warmup_epoch=warmup_epoch,
            warmup_start_lr=warmup_start_lr,
            last_epoch=last_epoch,
            by_epoch=by_epoch)
        self.eta_min = eta_min

        remaining_epochs = self.epochs - self.warmup_epoch
        if self.by_epoch:
            self.T_max = remaining_epochs
        else:
            self.T_max = remaining_epochs * self.step_each_epoch

    def __call__(self):
        """Create the cosine scheduler, wrapped in warmup when configured
        """
        if self.T_max > 0:
            scheduler = lr.CosineAnnealingDecay(
                learning_rate=self.learning_rate,
                T_max=self.T_max,
                eta_min=self.eta_min,
                last_epoch=self.last_epoch)
        else:
            # no annealing period left: use the plain learning rate value
            scheduler = self.learning_rate

        if self.warmup_steps > 0:
            scheduler = self.linear_warmup(scheduler)

        # tag the scheduler with the stepping unit it was built for
        scheduler.by_epoch = self.by_epoch
        return scheduler


H
HydrogenSulfate 已提交
256 257 258
class Step(LRBase):
    """Step learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        step_size (int): the interval to update, given in epochs.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Defaults to 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
                                   warmup_epoch, warmup_start_lr, last_epoch,
                                   by_epoch)
        self.gamma = gamma
        # step_size is given in epochs; convert it to iterations unless the
        # scheduler is stepped once per epoch
        if self.by_epoch:
            self.step_size = step_size
        else:
            self.step_size = step_size * step_each_epoch

    def __call__(self):
        """Create the step decay scheduler, wrapped in warmup when configured
        """
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


H
HydrogenSulfate 已提交
304 305 306
class Piecewise(LRBase):
    """Piecewise learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        decay_epochs (List[int]): The epochs at which the learning rate switches to the next entry of ``values``.
        values (List[float]): A list of learning rate values that will be picked during different epoch boundaries. ``values[0]`` is also used as the learning rate reached at the end of warmup.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Piecewise,
              self).__init__(epochs, step_each_epoch, values[0], warmup_epoch,
                             warmup_start_lr, last_epoch, by_epoch)
        self.values = values
        # Use the same truthiness test as LRBase uses for warmup_steps, so a
        # truthy non-bool by_epoch cannot yield epoch-based warmup together
        # with iteration-based boundaries.
        if self.by_epoch:
            self.boundaries_steps = decay_epochs
        else:
            self.boundaries_steps = [e * step_each_epoch for e in decay_epochs]

    def __call__(self):
        """Create the piecewise decay scheduler, wrapped in warmup when configured
        """
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries_steps,
            values=self.values,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
D
dongshuilong 已提交
347 348


H
HydrogenSulfate 已提交
349 350 351
class MultiStepDecay(LRBase):
    """Builder for a multi-step learning rate decay with optional linear warmup

    Milestones are given in epochs; they are converted to iterations unless
    ``by_epoch`` is True.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        milestones (List[int]): List of each boundaries. Must be increasing.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Defaults to 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        # extra keyword arguments are accepted and ignored
        super().__init__(
            epochs=epochs,
            step_each_epoch=step_each_epoch,
            learning_rate=learning_rate,
            warmup_epoch=warmup_epoch,
            warmup_start_lr=warmup_start_lr,
            last_epoch=last_epoch,
            by_epoch=by_epoch)
        self.gamma = gamma

        if self.by_epoch:
            self.milestones = milestones
        else:
            self.milestones = [
                milestone * step_each_epoch for milestone in milestones
            ]

    def __call__(self):
        """Create the multi-step scheduler, wrapped in warmup when configured
        """
        scheduler = lr.MultiStepDecay(
            learning_rate=self.learning_rate,
            milestones=self.milestones,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            scheduler = self.linear_warmup(scheduler)

        # tag the scheduler with the stepping unit it was built for
        scheduler.by_epoch = self.by_epoch
        return scheduler