learning_rate.py 24.3 KB
Newer Older
S
shippingwang 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2
#
S
shippingwang 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
S
shippingwang 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

D
dongshuilong 已提交
15 16
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
D
dongshuilong 已提交
17
import math
H
HydrogenSulfate 已提交
18
import types
H
HydrogenSulfate 已提交
19 20 21
from abc import abstractmethod
from typing import Union
from paddle.optimizer import lr
22 23
from ppcls.utils import logger

W
WuHaobo 已提交
24

H
HydrogenSulfate 已提交
25 26 27 28 29 30 31 32 33 34 35 36
class LRBase(object):
    """Base class for custom learning rates

    Subclasses store their hyper-parameters in ``__init__`` and build a
    ``paddle.optimizer.lr`` scheduler in ``__call__``.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s). A value that is not
            smaller than ``epochs`` is replaced by ``epochs`` (with a warning).
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
        verbose (bool): If True, prints a message to stdout for each update. Defaults to False
    """

    def __init__(self,
                 epochs: int,
                 step_each_epoch: int,
                 learning_rate: float,
                 warmup_epoch: int,
                 warmup_start_lr: float,
                 last_epoch: int,
                 by_epoch: bool,
                 verbose: bool=False) -> None:
        """Clamp the warmup length and record every hyper-parameter."""
        super(LRBase, self).__init__()

        warmup_too_long = warmup_epoch >= epochs
        if warmup_too_long:
            # Warmup cannot outlast training, so shorten it to the whole run.
            logger.warning(
                f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            )
            warmup_epoch = epochs

        # Warmup length in scheduler steps: one step per epoch when
        # ``by_epoch`` is set, otherwise one step per iteration (rounded so a
        # fractional warmup_epoch still yields a whole number of iterations).
        if by_epoch:
            warmup_steps = warmup_epoch
        else:
            warmup_steps = round(warmup_epoch * step_each_epoch)

        self.epochs = epochs
        self.step_each_epoch = step_each_epoch
        self.learning_rate = learning_rate
        self.warmup_epoch = warmup_epoch
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.last_epoch = last_epoch
        self.by_epoch = by_epoch
        self.verbose = verbose

    @abstractmethod
    def __call__(self, *kargs, **kwargs) -> lr.LRScheduler:
        """Build and return the learning rate scheduler.

        Returns:
            lr.LRScheduler: learning rate scheduler
        """
        pass

    def linear_warmup(
            self,
            learning_rate: Union[float, lr.LRScheduler]) -> lr.LinearWarmup:
        """Prepend a linear warmup to ``learning_rate``.

        The rate ramps from ``self.warmup_start_lr`` to ``self.learning_rate``
        over ``self.warmup_steps`` scheduler steps.

        Args:
            learning_rate (Union[float, lr.LRScheduler]): original learning rate without warmup

        Returns:
            lr.LinearWarmup: learning rate scheduler with warmup
        """
        return lr.LinearWarmup(
            learning_rate=learning_rate,
            warmup_steps=self.warmup_steps,
            start_lr=self.warmup_start_lr,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch,
            verbose=self.verbose)


H
HydrogenSulfate 已提交
96
class Constant(lr.LRScheduler):
    """Constant learning rate Class implementation

    ``get_lr`` always returns ``learning_rate``, whatever the step count.

    Args:
        learning_rate (float): The initial learning rate
        last_epoch (int, optional): The index of last epoch. Default: -1.
    """

    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
        # Must be set before the base initializer runs: paddle's
        # LRScheduler.__init__ calls step(), which calls get_lr(), which
        # reads self.learning_rate.
        self.learning_rate = learning_rate
        # Forward the real values rather than relying on LRScheduler's
        # defaults (learning_rate=0.1, last_epoch=-1). The base initializer
        # assigns self.last_epoch from its own argument, so calling it with no
        # arguments discarded the caller's last_epoch and recorded a base_lr
        # of 0.1.
        super(Constant, self).__init__(
            learning_rate=float(learning_rate), last_epoch=last_epoch)

    def get_lr(self) -> float:
        """Always return the same learning rate."""
        return self.learning_rate


H
HydrogenSulfate 已提交
115
class ConstLR(LRBase):
    """Constant learning rate, optionally preceded by a linear warmup

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(ConstLR, self).__init__(
            epochs=epochs,
            step_each_epoch=step_each_epoch,
            learning_rate=learning_rate,
            warmup_epoch=warmup_epoch,
            warmup_start_lr=warmup_start_lr,
            last_epoch=last_epoch,
            by_epoch=by_epoch)

    def __call__(self):
        """Build the constant scheduler, wrapped in a warmup when requested."""
        base = Constant(
            learning_rate=self.learning_rate, last_epoch=self.last_epoch)
        scheduler = self.linear_warmup(base) if self.warmup_steps > 0 else base
        # Tag the scheduler with its stepping granularity (epoch vs. iteration).
        scheduler.by_epoch = self.by_epoch
        return scheduler


class Linear(LRBase):
    """Linear learning rate decay

    Implemented with ``lr.PolynomialDecay``; with ``power=1.0`` the rate falls
    in a straight line from ``learning_rate`` to ``end_lr`` over the steps that
    remain after warmup.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        end_lr (float, optional): The minimum final learning rate. Defaults to 0.0.
        power (float, optional): Power of polynomial. Defaults to 1.0.
        cycle (bool, optional): Whether the learning rate rises again after it
            has decayed to ``end_lr``; forwarded to ``lr.PolynomialDecay``. Defaults to False.
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 end_lr=0.0,
                 power=1.0,
                 cycle=False,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Linear, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        # NOTE: self.warmup_steps comes from LRBase and already respects
        # by_epoch. Do not recompute it in iterations here, or an
        # epoch-stepped warmup becomes step_each_epoch times too long.
        if self.by_epoch:
            self.decay_steps = self.epochs - self.warmup_epoch
        else:
            # Iterations left after warmup. Subtracting the rounded
            # warmup_steps keeps this an int (lr.PolynomialDecay expects an
            # integer decay_steps) even when warmup_epoch is fractional.
            self.decay_steps = (
                self.epochs * self.step_each_epoch - self.warmup_steps)

    def __call__(self):
        """Build the decay scheduler, wrapped in a warmup when requested."""
        if self.decay_steps > 0:
            learning_rate = lr.PolynomialDecay(
                learning_rate=self.learning_rate,
                decay_steps=self.decay_steps,
                end_lr=self.end_lr,
                power=self.power,
                cycle=self.cycle,
                last_epoch=self.last_epoch)
        else:
            # No steps remain after warmup: hold the rate constant.
            learning_rate = Constant(
                self.learning_rate, last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
205 206


H
HydrogenSulfate 已提交
207 208
class Cosine(LRBase):
    """Cosine learning rate decay

    After warmup, the rate follows ``lr.CosineAnnealingDecay``:
    ``lr = eta_min + 0.5 * (learning_rate - eta_min) * (1 + math.cos(math.pi * t / T_max))``,
    where ``t`` counts scheduler steps after warmup and ``T_max`` is the number
    of steps that remain after warmup.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 eta_min=0.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Cosine, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.eta_min = eta_min
        if self.by_epoch:
            self.T_max = self.epochs - self.warmup_epoch
        else:
            # Iterations left after warmup. Subtracting the rounded
            # warmup_steps keeps T_max an int (lr.CosineAnnealingDecay
            # rejects a non-int T_max) even when warmup_epoch is fractional.
            self.T_max = self.epochs * self.step_each_epoch - self.warmup_steps

    def __call__(self):
        """Build the cosine scheduler, wrapped in a warmup when requested."""
        if self.T_max > 0:
            learning_rate = lr.CosineAnnealingDecay(
                learning_rate=self.learning_rate,
                T_max=self.T_max,
                eta_min=self.eta_min,
                last_epoch=self.last_epoch)
        else:
            # No steps remain after warmup: hold the rate constant.
            learning_rate = Constant(
                self.learning_rate, last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


T
tianyi1997 已提交
256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
class Cyclic(LRBase):
    """Cyclic learning rate decay

    ``base_learning_rate`` is also the ``learning_rate`` recorded by LRBase, so
    it is the end point of the optional linear warmup.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
            that set the base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
            Since there is some scaling operation during process of learning rate adjustment,
            max_learning_rate may not actually be reached.
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        step_size_up (int): Number of training steps, which is used to increase learning rate in a cycle.
            The step size of one cycle will be defined by step_size_up + step_size_down. According to the paper, step
            size should be set as at least 3 or 4 times steps in one epoch.
        step_size_down (int, optional): Number of training steps, which is used to decrease learning rate in a cycle.
            If not specified, its value will initialize to ``step_size_up``. Default: None
        mode (str, optional): one of 'triangular', 'triangular2' or 'exp_range'.
            If scale_fn is specified, this argument will be ignored. Default: 'triangular'
        exp_gamma (float): Constant in 'exp_range' scaling function: exp_gamma**iterations. Used only when mode = 'exp_range'. Default: 1.0
        scale_fn (function, optional): A custom scaling function, which is used to replace the three built-in methods.
            It should only have one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
            If specified, then 'mode' will be ignored. Default: None
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on cycle
            number or cycle iterations (total iterations since start of training). Default: 'cycle'
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
        verbose (bool, optional): If True, prints a message to stdout for each update. Defaults to False
        **kwargs: extra keyword arguments are accepted and ignored, as the
            other schedulers in this module do.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 base_learning_rate,
                 max_learning_rate,
                 warmup_epoch,
                 warmup_start_lr,
                 step_size_up,
                 step_size_down=None,
                 mode='triangular',
                 exp_gamma=1.0,
                 scale_fn=None,
                 scale_mode='cycle',
                 by_epoch=False,
                 last_epoch=-1,
                 verbose=False,
                 **kwargs):

        super(Cyclic, self).__init__(
            epochs, step_each_epoch, base_learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch, verbose)
        self.base_learning_rate = base_learning_rate
        self.max_learning_rate = max_learning_rate
        self.step_size_up = step_size_up
        self.step_size_down = step_size_down
        self.mode = mode
        self.exp_gamma = exp_gamma
        self.scale_fn = scale_fn
        self.scale_mode = scale_mode

    def __call__(self):
        """Build the cyclic scheduler, wrapped in a warmup when requested."""
        learning_rate = lr.CyclicLR(
            base_learning_rate=self.base_learning_rate,
            max_learning_rate=self.max_learning_rate,
            step_size_up=self.step_size_up,
            step_size_down=self.step_size_down,
            mode=self.mode,
            exp_gamma=self.exp_gamma,
            scale_fn=self.scale_fn,
            scale_mode=self.scale_mode,
            last_epoch=self.last_epoch,
            verbose=self.verbose)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


H
HydrogenSulfate 已提交
336 337 338
class Step(LRBase):
    """Step learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        step_size (int): the interval, in epochs, between two decays. It is
            converted to iterations (``step_size * step_each_epoch``) unless
            ``by_epoch`` is True.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        # gamma defaults to 0.1 as documented above (previously it was
        # required despite the documented default).
        super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
                                   warmup_epoch, warmup_start_lr, last_epoch,
                                   by_epoch)
        self.gamma = gamma
        if self.by_epoch:
            self.step_size = step_size
        else:
            self.step_size = step_size * step_each_epoch

    def __call__(self):
        """Build the step-decay scheduler, wrapped in a warmup when requested."""
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


H
HydrogenSulfate 已提交
384 385 386
class Piecewise(LRBase):
    """Piecewise learning rate decay

    ``values[0]`` is passed to LRBase as ``learning_rate``, so it is also the
    end point of the optional linear warmup.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        decay_epochs (List[int]): Epochs at which the learning rate switches to
            the next value. They are converted to iterations
            (``epoch * step_each_epoch``) unless ``by_epoch`` is True.
        values (List[float]): A list of learning rate values that will be picked during different epoch boundaries.
            It should have one more element than ``decay_epochs``, as ``lr.PiecewiseDecay`` expects.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Piecewise,
              self).__init__(epochs, step_each_epoch, values[0], warmup_epoch,
                             warmup_start_lr, last_epoch, by_epoch)
        self.values = values
        # Use a truthiness test like the sibling schedulers, so any truthy
        # by_epoch selects epoch boundaries.
        if self.by_epoch:
            self.boundaries_steps = decay_epochs
        else:
            self.boundaries_steps = [e * step_each_epoch for e in decay_epochs]

    def __call__(self):
        """Build the piecewise scheduler, wrapped in a warmup when requested."""
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries_steps,
            values=self.values,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
D
dongshuilong 已提交
427 428


H
HydrogenSulfate 已提交
429 430 431
class MultiStepDecay(LRBase):
    """MultiStepDecay learning rate decay

    Wraps ``lr.MultiStepDecay``, which multiplies the rate by ``gamma`` at each
    milestone.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        milestones (List[int]): Epochs at which to decay. Must be increasing.
            They are converted to iterations (``epoch * step_each_epoch``)
            unless ``by_epoch`` is True.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Defaults to 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(MultiStepDecay, self).__init__(
            epochs, step_each_epoch, learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch)
        self.gamma = gamma
        per_iteration = [epoch * step_each_epoch for epoch in milestones]
        self.milestones = milestones if self.by_epoch else per_iteration

    def __call__(self):
        """Build the multi-step scheduler, wrapped in a warmup when requested."""
        scheduler = lr.MultiStepDecay(
            learning_rate=self.learning_rate,
            milestones=self.milestones,
            gamma=self.gamma,
            last_epoch=self.last_epoch)
        if self.warmup_steps > 0:
            scheduler = self.linear_warmup(scheduler)
        scheduler.by_epoch = self.by_epoch
        return scheduler
H
add xbm  
HydrogenSulfate 已提交
475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503


class ReduceOnPlateau(LRBase):
    """ReduceOnPlateau learning rate decay

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'``, the learning
            rate will reduce when ``loss`` stops ascending. Defaults to ``'min'``.
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
            It should be less than 1.0. Defaults to 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learning rate will be reduced.
            Defaults to 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
            This make tiny changes of ``loss`` will be ignored. Defaults to 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Defaults to ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Defaults to 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Defaults to 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
            the update is ignored. Defaults to 1e-8.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Recorded by LRBase (used by the
            warmup wrapper) but not passed to ``lr.ReduceOnPlateau``. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(ReduceOnPlateau, self).__init__(
            epochs, step_each_epoch, learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch)
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

    def __call__(self):
        """Build the plateau scheduler, wrapped in a warmup when requested."""
        learning_rate = lr.ReduceOnPlateau(
            learning_rate=self.learning_rate,
            mode=self.mode,
            factor=self.factor,
            patience=self.patience,
            threshold=self.threshold,
            threshold_mode=self.threshold_mode,
            cooldown=self.cooldown,
            min_lr=self.min_lr,
            epsilon=self.epsilon)

        # NOTE: Implement get_lr() method for class `ReduceOnPlateau`,
        # which is called in `log_info` function.
        # The override is attached to the ReduceOnPlateau instance itself,
        # *before* warmup wrapping. If it were attached to the LinearWarmup
        # wrapper, the wrapper's step() (which recomputes last_lr via
        # get_lr()) would keep returning the stale last_lr, and the rate would
        # stay at warmup_start_lr forever.
        def get_lr(self):
            return self.last_lr

        learning_rate.get_lr = types.MethodType(get_lr, learning_rate)

        if self.warmup_steps > 0:
            # NOTE(review): once warmup ends, LinearWarmup drives the inner
            # scheduler through step(epoch), whereas ReduceOnPlateau expects
            # step(metrics). Confirm that warmup + ReduceOnPlateau is a
            # supported combination.
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
D
dongshuilong 已提交
558 559 560 561 562 563 564 565 566 567 568 569 570 571


class CosineFixmatch(LRBase):
    """Cosine decay in FixMatch style

    ``lr.LambdaDecay`` multiplies ``learning_rate`` by a factor that rises
    linearly from 0 to 1 over ``num_warmup_steps`` scheduler steps. After that
    the factor is ``max(0, cos(pi * num_cycles * progress))``, where
    ``progress`` goes from 0 to 1 over the remaining
    ``epochs * step_each_epoch - num_warmup_steps`` steps.

    Unlike the other schedulers in this module, this class does not call
    ``LRBase.__init__``; its warmup is built into the lambda rather than
    added with ``linear_warmup``.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        num_warmup_steps (int): the number of warmup steps.
        num_cycles (float, optional): the factor for cosine in FixMatch learning rate. Defaults to 7 / 16.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
        **kwargs: extra keyword arguments are accepted and ignored, as the
            other schedulers in this module do.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 num_warmup_steps,
                 num_cycles=7 / 16,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        self.epochs = epochs
        self.step_each_epoch = step_each_epoch
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_cycles = num_cycles
        self.last_epoch = last_epoch
        self.by_epoch = by_epoch

    def __call__(self):
        """Build the FixMatch cosine scheduler as an ``lr.LambdaDecay``."""

        def _lr_lambda(current_step):
            # Linear warmup; max(1, ...) guards against num_warmup_steps == 0.
            if current_step < self.num_warmup_steps:
                return float(current_step) / float(
                    max(1, self.num_warmup_steps))
            decay_span = max(
                1, self.epochs * self.step_each_epoch - self.num_warmup_steps)
            progress = float(current_step - self.num_warmup_steps) / float(
                decay_span)
            # Clamp at 0 so the factor never goes negative.
            return max(0., math.cos(math.pi * self.num_cycles * progress))

        learning_rate = lr.LambdaDecay(
            learning_rate=self.learning_rate,
            lr_lambda=_lr_lambda,
            last_epoch=self.last_epoch)
        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate