learning_rate.py 14.0 KB
Newer Older
W
WenmuZhou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

W
WenmuZhou 已提交
20
from paddle.optimizer import lr
B
bupt906 已提交
21
from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
W
WenmuZhou 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35


class Linear(object):
    """
    Polynomial learning rate decay (linear when ``power == 1``), with an
    optional linear warmup.

    Epoch-based arguments are converted to iteration counts by multiplying
    them with ``step_each_epoch``.

    Args:
        learning_rate (float): The initial learning rate. It is also the value
            the warmup ramps up to.
        epochs (int): Number of training epochs; the decay spans
            ``epochs * step_each_epoch`` iterations.
        step_each_epoch (int): Number of iterations per epoch.
        end_lr (float, optional): The minimum final learning rate. Default: 0.0.
        power (float, optional): Power of the polynomial. Default: 1.0.
        warmup_epoch (int|float, optional): Number of warmup epochs. When
            ``round(warmup_epoch * step_each_epoch)`` is positive, the decay is
            wrapped in a linear warmup from 0.0 to ``learning_rate``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 epochs,
                 step_each_epoch,
                 end_lr=0.0,
                 power=1.0,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Linear, self).__init__()
        # Every length below is stored in iterations, not epochs.
        self.epochs = epochs * step_each_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
        self.learning_rate = learning_rate
        self.end_lr, self.power = end_lr, power
        self.last_epoch = last_epoch

    def __call__(self):
        """Build and return the Paddle learning-rate scheduler."""
        scheduler = lr.PolynomialDecay(
            learning_rate=self.learning_rate,
            decay_steps=self.epochs,
            end_lr=self.end_lr,
            power=self.power,
            last_epoch=self.last_epoch)
        if self.warmup_epoch <= 0:
            return scheduler
        # Wrap the decay in a linear warmup from 0.0 to the base learning rate.
        return lr.LinearWarmup(
            learning_rate=scheduler,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)


class Cosine(object):
    """
    Cosine annealing learning rate decay with an optional linear warmup.

    With Paddle's default ``eta_min`` of 0, the decayed rate follows
    ``lr = 0.5 * learning_rate * (1 + cos(pi * step / T_max))``, where
    ``T_max = step_each_epoch * epochs``.

    Args:
        learning_rate (float): The initial learning rate. It is also the value
            the warmup ramps up to.
        step_each_epoch (int): Number of iterations per epoch.
        epochs (int): Total number of training epochs.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Cosine, self).__init__()
        self.T_max = step_each_epoch * epochs
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch

    def __call__(self):
        """Build and return the Paddle learning-rate scheduler."""
        annealing = lr.CosineAnnealingDecay(
            learning_rate=self.learning_rate,
            T_max=self.T_max,
            last_epoch=self.last_epoch)
        if not self.warmup_epoch > 0:
            return annealing
        # Linear ramp from 0.0 to the base learning rate before annealing.
        return lr.LinearWarmup(
            learning_rate=annealing,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)


class Step(object):
    """
    Step learning rate decay: the rate is multiplied by ``gamma`` every
    ``step_size`` epochs. An optional linear warmup can be added.

    Args:
        learning_rate (float): The initial learning rate. It is also the value
            the warmup ramps up to.
        step_size (int): Decay interval in epochs; converted to iterations as
            ``step_each_epoch * step_size``.
        step_each_epoch (int): Number of iterations per epoch.
        gamma (float): The ratio the learning rate is multiplied by at each
            decay, i.e. ``new_lr = origin_lr * gamma``. Should be below 1.0.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 step_each_epoch,
                 gamma,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Step, self).__init__()
        self.learning_rate, self.gamma = learning_rate, gamma
        self.step_size = step_each_epoch * step_size
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        """Build and return the Paddle learning-rate scheduler."""
        decay = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)
        if self.warmup_epoch <= 0:
            return decay
        # Prepend a linear ramp from 0.0 to the base learning rate.
        return lr.LinearWarmup(
            learning_rate=decay,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)


class Piecewise(object):
    """
    Piecewise-constant learning rate decay with an optional linear warmup.

    Args:
        step_each_epoch (int): Number of iterations per epoch.
        decay_epochs (list): Epoch indices at which the learning rate switches
            to the next value. Each is converted to iterations as
            ``step_each_epoch * epoch``.
        values (list): Learning rate values used between consecutive
            boundaries. ``values[0]`` is also the value the warmup ramps up to.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Piecewise, self).__init__()
        self.boundaries = list(step_each_epoch * epoch for epoch in decay_epochs)
        self.values = values
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
        self.last_epoch = last_epoch

    def __call__(self):
        """Build and return the Paddle learning-rate scheduler."""
        piecewise = lr.PiecewiseDecay(
            boundaries=self.boundaries,
            values=self.values,
            last_epoch=self.last_epoch)
        if self.warmup_epoch <= 0:
            return piecewise
        # Warm up towards the first piecewise value rather than a base rate.
        return lr.LinearWarmup(
            learning_rate=piecewise,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.values[0],
            last_epoch=self.last_epoch)
187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228


class CyclicalCosine(object):
    """
    Cyclical cosine learning rate decay (via ``CyclicalCosineDecay``) with an
    optional linear warmup.

    Args:
        learning_rate (float): The initial learning rate. It is also the value
            the warmup ramps up to.
        step_each_epoch (int): Number of iterations per epoch.
        epochs (int): Total number of training epochs; passed on as
            ``T_max = step_each_epoch * epochs``.
        cycle (int|float): Period of the cosine learning rate in epochs,
            converted to iterations with ``round``.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 cycle,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(CyclicalCosine, self).__init__()
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch
        # Convert every epoch-based length into iterations.
        self.T_max = step_each_epoch * epochs
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
        self.cycle = round(cycle * step_each_epoch)

    def __call__(self):
        """Build and return the learning-rate scheduler."""
        cyclic = CyclicalCosineDecay(
            learning_rate=self.learning_rate,
            T_max=self.T_max,
            cycle=self.cycle,
            last_epoch=self.last_epoch)
        if self.warmup_epoch <= 0:
            return cyclic
        return lr.LinearWarmup(
            learning_rate=cyclic,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)
B
bupt906 已提交
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277


class OneCycle(object):
    """
    One-cycle learning rate schedule (via ``OneCycleDecay``) with an optional
    linear warmup.

    Args:
        max_lr (float): Upper learning rate boundary. It is also the value the
            warmup ramps up to.
        epochs (int): Total number of training epochs.
        step_each_epoch (int): Number of iterations per epoch; forwarded as
            ``steps_per_epoch``.
        anneal_strategy (str, optional): 'cos' for cosine annealing or
            'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use a three-phase schedule;
            passed through to ``OneCycleDecay`` (see that class for the exact
            phase layout). Default: False.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 max_lr,
                 epochs,
                 step_each_epoch,
                 anneal_strategy='cos',
                 three_phase=False,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(OneCycle, self).__init__()
        self.max_lr, self.epochs = max_lr, epochs
        self.steps_per_epoch = step_each_epoch
        self.anneal_strategy, self.three_phase = anneal_strategy, three_phase
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        """Build and return the learning-rate scheduler."""
        one_cycle = OneCycleDecay(
            max_lr=self.max_lr,
            epochs=self.epochs,
            steps_per_epoch=self.steps_per_epoch,
            anneal_strategy=self.anneal_strategy,
            three_phase=self.three_phase,
            last_epoch=self.last_epoch)
        if self.warmup_epoch <= 0:
            return one_cycle
        # Warm up linearly from 0.0 to the peak learning rate.
        return lr.LinearWarmup(
            learning_rate=one_cycle,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.max_lr,
            last_epoch=self.last_epoch)


class Const(object):
    """
    Constant learning rate, optionally preceded by a linear warmup.

    Calling an instance returns the plain ``learning_rate`` value when there
    is no warmup, otherwise a Paddle ``LinearWarmup`` scheduler that ramps
    from 0.0 to ``learning_rate`` and then stays there.

    Args:
        learning_rate (float): The constant learning rate.
        step_each_epoch (int): Number of iterations per epoch.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(Const, self).__init__()
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch

    def __call__(self):
        """Return the constant rate, or a warmup scheduler wrapping it."""
        if self.warmup_epoch <= 0:
            return self.learning_rate
        return lr.LinearWarmup(
            learning_rate=self.learning_rate,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)
文幕地方's avatar
文幕地方 已提交
311 312


W
wangjingyeye 已提交
313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345
class DecayLearningRate(object):
    """
    Polynomial learning rate decay over the whole training run.

    new_lr = (learning_rate - end_lr) * (1 - step / decay_steps) ** factor + end_lr

    where ``step`` is the current iteration (clamped to ``decay_steps`` by
    Paddle's ``PolynomialDecay``) and ``decay_steps = step_each_epoch * epochs``.

    Args:
        learning_rate (float): The initial learning rate.
        step_each_epoch (int): Number of iterations per epoch.
        epochs (int): Total number of training epochs.
        factor (float): Power of the polynomial. Should be greater than 0.0
            for the learning rate to decay. Default: 0.9.
        end_lr (float): The minimum final learning rate. Default: 0.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 factor=0.9,
                 end_lr=0,
                 **kwargs):
        super(DecayLearningRate, self).__init__()
        self.learning_rate = learning_rate
        # NOTE(review): not read by __call__ in this class; kept unchanged in
        # case other code inspects it.
        self.epochs = epochs + 1
        self.factor = factor
        # Honour the configured floor. Previously this was hard-coded to 0,
        # which silently discarded the ``end_lr`` argument.
        self.end_lr = end_lr
        self.decay_steps = step_each_epoch * epochs

    def __call__(self):
        """Build and return the Paddle ``PolynomialDecay`` scheduler."""
        learning_rate = lr.PolynomialDecay(
            learning_rate=self.learning_rate,
            decay_steps=self.decay_steps,
            power=self.factor,
            end_lr=self.end_lr)
        return learning_rate
346 347


文幕地方's avatar
文幕地方 已提交
348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
class MultiStepDecay(object):
    """
    Multi-step learning rate decay: the rate is multiplied by ``gamma`` at
    each milestone. An optional linear warmup can be added.

    Args:
        learning_rate (float): The initial learning rate. It is also the value
            the warmup ramps up to.
        milestones (list): Epochs at which the learning rate decays. Each is
            converted to iterations as ``step_each_epoch * epoch``.
        step_each_epoch (int): Number of iterations per epoch.
        gamma (float): The ratio the learning rate is multiplied by at each
            milestone, i.e. ``new_lr = origin_lr * gamma``. Should be below 1.0.
        warmup_epoch (int|float, optional): Number of warmup epochs, converted
            to iterations with ``round``. Default: 0.
        last_epoch (int, optional): The index of last epoch. Can be set to
            restart training. Default: -1, means initial learning rate.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 step_each_epoch,
                 gamma,
                 warmup_epoch=0,
                 last_epoch=-1,
                 **kwargs):
        super(MultiStepDecay, self).__init__()
        scaled = []
        for epoch in milestones:
            scaled.append(step_each_epoch * epoch)
        self.milestones = scaled
        self.learning_rate, self.gamma = learning_rate, gamma
        self.last_epoch = last_epoch
        self.warmup_epoch = round(warmup_epoch * step_each_epoch)

    def __call__(self):
        """Build and return the Paddle learning-rate scheduler."""
        # ``lr.MultiStepDecay`` is Paddle's scheduler, not this wrapper class.
        multi_step = lr.MultiStepDecay(
            learning_rate=self.learning_rate,
            milestones=self.milestones,
            gamma=self.gamma,
            last_epoch=self.last_epoch)
        if self.warmup_epoch <= 0:
            return multi_step
        return lr.LinearWarmup(
            learning_rate=multi_step,
            warmup_steps=self.warmup_epoch,
            start_lr=0.0,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch)