learning_rate_scheduler.py 23.3 KB
Newer Older
M
minqiyang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

M
minqiyang 已提交
17 18
import math

M
minqiyang 已提交
19 20
from .. import unique_name

21
# Public schedulers re-exported by the enclosing package. LinearLrWarmup is
# listed so that the usage shown in its own docstring
# (fluid.dygraph.LinearLrWarmup) resolves.
__all__ = [
    'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay', 'ExponentialDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'CosineDecay', 'LinearLrWarmup'
]
M
minqiyang 已提交
25 26 27 28 29


class LearningRateDecay(object):
    """
    Base class of learning rate decay.

    Defines the common interface of a learning rate scheduler. Users should
    not use this class directly, but one of its implementations.

    Calling an instance evaluates ``step()`` for the current ``step_num``,
    wraps a plain Python number into a learning rate variable and then
    advances ``step_num`` by ``step_size``.

    Parameters:
        begin(int, optional): Initial value of the step counter. The default value is 0.
        step(int, optional): Amount added to the step counter after every call.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable.
            The default value is 'float32'.
    """

    def __init__(self, begin=0, step=1, dtype='float32'):
        self.step_num = begin
        self.step_size = step
        self.dtype = dtype

    def __call__(self):
        """
        Compute the learning rate for the current step and advance the counter.

        Returns:
            The value produced by ``step()``; plain Python numbers are
            converted with ``create_lr_var`` first.
        """
        lr = self.step()
        # Sub-classes may hand back a raw Python number (float, or int when
        # the user configured integer rates); both must become variables.
        if isinstance(lr, (int, float)):
            lr = self.create_lr_var(lr)
        self.step_num += self.step_size
        return lr

    def create_lr_var(self, lr):
        """
        Convert lr from a Python number to a variable.

        Args:
            lr: learning rate
        Returns:
            learning rate variable
        """
        # Imported lazily, as in the rest of this module.
        from .. import layers
        lr = layers.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(lr),
            dtype=self.dtype,
            persistable=False)
        return lr

    def step(self):
        """Return the learning rate for ``self.step_num``; implemented by sub-classes."""
        raise NotImplementedError()


M
minqiyang 已提交
70
class PiecewiseDecay(LearningRateDecay):
    """
    Piecewise decay scheduler.

    The learning rate is picked from ``values`` according to the interval of
    ``boundaries`` the current step falls into:

    .. code-block:: text

        boundaries = [10000, 20000]
        values = [1.0, 0.5, 0.1]
        if global_step < 10000:
            learning_rate = 1.0
        elif 10000 <= global_step < 20000:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Parameters:
        boundaries(list): A list of steps numbers. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during
            different step boundaries. The type of element in the list is python float.
        begin(int): The begin step to initialize the global_step in the description above.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.PiecewiseDecay(boundaries, values, 0),
                  parameter_list=emb.parameters())
    """

    def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
        super(PiecewiseDecay, self).__init__(begin, step, dtype)
        self.boundaries = boundaries
        self.values = values

        # Separate list holding the candidate rates looked up by step().
        self.vars = list(values)

    def step(self):
        """Return the rate of the first interval whose upper boundary exceeds the current step."""
        for idx, boundary in enumerate(self.boundaries):
            if self.step_num < boundary:
                return self.vars[idx]
        # Past every boundary: use the last configured value.
        last_idx = len(self.values) - 1
        return self.create_lr_var(self.vars[last_idx])
127 128 129


class NaturalExpDecay(LearningRateDecay):
    r"""
    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        decayed\_learning\_rate = learning\_rate * e^{y}

    If staircase is set to False, then:

    .. math::

        y = - decay\_rate * \frac{global\_step}{decay\_steps}

    If staircase is set to True, then:

    .. math::

        y = - decay\_rate * math.floor(\frac{global\_step}{decay\_steps})

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.NaturalExpDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True))

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(NaturalExpDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return ``learning_rate * exp(-decay_rate * global_step / decay_steps)``."""
        from .. import layers
        # float() forces true division: under Python 2 ``int / int`` would
        # truncate and silently turn the smooth schedule into a staircase.
        div_res = self.create_lr_var(self.step_num / float(self.decay_steps))
        if self.staircase:
            div_res = layers.floor(div_res)
        decayed_lr = self.learning_rate * layers.exp(-1 * self.decay_rate *
                                                     div_res)

        return decayed_lr


class ExponentialDecay(LearningRateDecay):
    r"""
    Applies exponential decay to the learning rate.

    The algorithm can be described as following.

    .. math::

        decayed\_learning\_rate = learning\_rate * decay\_rate ^ y

    If staircase is set to False, then:

    .. math::

        y = \frac{global\_step}{decay\_steps}

    If staircase is set to True, then:

    .. math::

        y = math.floor(\frac{global\_step}{decay\_steps})


    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.ExponentialDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True))

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(ExponentialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return ``learning_rate * decay_rate ** (global_step / decay_steps)``."""
        from .. import layers
        # float() forces true division: under Python 2 ``int / int`` would
        # truncate and silently turn the smooth schedule into a staircase.
        div_res = self.create_lr_var(self.step_num / float(self.decay_steps))
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate * (self.decay_rate**div_res)

        return decayed_lr


class InverseTimeDecay(LearningRateDecay):
    r"""
    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following.
    If staircase is set to False, then:

    .. math::

        decayed\_learning\_rate = \frac{learning\_rate}{1 + decay\_rate * \frac{global\_step}{decay\_steps}}

    If staircase is set to True, then:

    .. math::

        decayed\_learning\_rate = \frac{learning\_rate}{1 + decay\_rate * math.floor(\frac{global\_step}{decay\_steps})}

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.InverseTimeDecay(
                      learning_rate=base_lr,
                      decay_steps=10000,
                      decay_rate=0.5,
                      staircase=True),
                  parameter_list=emb.parameters())

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 decay_rate,
                 staircase=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(InverseTimeDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return ``learning_rate / (1 + decay_rate * global_step / decay_steps)``."""
        from .. import layers
        # float() forces true division: under Python 2 ``int / int`` would
        # truncate and silently turn the smooth schedule into a staircase.
        div_res = self.create_lr_var(self.step_num / float(self.decay_steps))
        if self.staircase:
            div_res = layers.floor(div_res)

        decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)

        return decayed_lr


class PolynomialDecay(LearningRateDecay):
    r"""
    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\frac{global\_step}{decay\_steps})

        decayed\_learning\_rate & = (learning\_rate-end\_learning\_rate)*(1-\frac{global\_step}{decay\_steps})^{power}+end\_learning\_rate

    If cycle is set to False, then:

    .. math::

        global\_step & = min(global\_step, decay\_steps)

        decayed\_learning\_rate & = (learning\_rate-end\_learning\_rate)*(1-\frac{global\_step}{decay\_steps})^{power}+end\_learning\_rate

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        end_learning_rate(float, optional): The minimum final learning rate. The default value is 0.0001.
        power(float, optional): Power of polynomial. The default value is 1.0.
        cycle(bool, optional): If set true, decay the learning rate every decay_steps. The default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.PolynomialDecay(
                      start_lr, total_step, end_lr, power=1.0),
                  parameter_list=emb.parameters())

    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_learning_rate=0.0001,
                 power=1.0,
                 cycle=False,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(PolynomialDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        """Return the polynomially decayed learning rate for the current step."""
        from .. import layers
        cur_step = self.step_num
        horizon = self.decay_steps
        if not self.cycle:
            # Without cycling the step is clamped to the decay horizon.
            if cur_step < self.decay_steps:
                clamped = cur_step
            else:
                clamped = self.decay_steps
            cur_step = self.create_lr_var(clamped)
        else:
            # With cycling the horizon is stretched to the next multiple of
            # decay_steps that covers the current step.
            progress = self.create_lr_var(cur_step / float(self.decay_steps))
            cycles = layers.ceil(progress)
            if cur_step == 0:
                # ceil(0) would give a zero horizon; use one full cycle.
                cycles = self.create_lr_var(1.0)
            horizon = self.decay_steps * cycles

        lr_span = self.learning_rate - self.end_learning_rate
        remaining = (1 - cur_step / horizon)**self.power
        return lr_span * remaining + self.end_learning_rate
454

M
minqiyang 已提交
455 456

class CosineDecay(LearningRateDecay):
    r"""
    Applies cosine decay to the learning rate.

    The algorithm can be described as following.

    .. math::

        epoch = math.floor(\frac{global\_step}{step\_each\_epoch})

        decayed\_learning\_rate = learning\_rate * 0.5 * (math.cos(epoch * \frac{math.pi}{epochs}) + 1)

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          base_lr = 0.1
          with fluid.dygraph.guard():
              optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.CosineDecay(
                      base_lr, 10000, 120))
    """

    def __init__(self,
                 learning_rate,
                 step_each_epoch,
                 epochs,
                 begin=0,
                 step=1,
                 dtype='float32'):
        super(CosineDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.step_each_epoch = step_each_epoch
        self.epochs = epochs

    def step(self):
        """Return the cosine-annealed learning rate for the epoch containing the current step."""
        from .. import layers
        # The quotient is floored right away, so the schedule only changes
        # once per epoch.
        epoch_progress = self.create_lr_var(self.step_num /
                                            self.step_each_epoch)
        cur_epoch = layers.floor(epoch_progress)
        half_lr = self.learning_rate * 0.5
        cosine_term = layers.cos(cur_epoch * math.pi / self.epochs)
        return half_lr * (cosine_term + 1)


class NoamDecay(LearningRateDecay):
    r"""
    Applies Noam decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        decayed\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(global\_step^{-0.5}, global\_step * warmup\_steps^{-1.5})

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_

    Parameters:
        d_model(Variable|int): The dimensionality of input and output feature vector of model. If type is Variable,
            it's a tensor with shape [1] and the data type can be int32 or int64. The type can also be python int.
        warmup_steps(Variable|int): The number of warmup steps. A super parameter. If type is Variable,
            it's a tensor with shape [1] and the data type can be int32 or int64. The type can also be python int.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 1.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.
        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number. Default 1.0

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          with fluid.dygraph.guard():
              emb = fluid.dygraph.Embedding([10, 10])
              optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.NoamDecay(
                      1 / (warmup_steps * (learning_rate ** 2)),
                      warmup_steps),
                  parameter_list=emb.parameters())
    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 begin=1,
                 step=1,
                 dtype='float32',
                 learning_rate=1.0):
        super(NoamDecay, self).__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self):
        """Return the Noam learning rate for the current step."""
        from .. import layers
        # Decay term: global_step ** -0.5.
        inv_sqrt_step = self.create_lr_var(self.step_num**-0.5)
        # Warm-up term: grows linearly with the step.
        warmup_ramp = self.create_lr_var(
            (self.warmup_steps**-1.5) * self.step_num)
        model_scale = self.learning_rate * (self.d_model**-0.5)
        return model_scale * layers.elementwise_min(inv_sqrt_step, warmup_ramp)
H
hong 已提交
575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605


class LinearLrWarmup(LearningRateDecay):
    """
    This operator uses the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning_rate after warm-up.

    Args:
        learning_rate (int|float|LearningRateDecay): Learning rate after warm-up. It can be a
            python number or another scheduler of this module, which is then evaluated on every step.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. Must be greater than start_lr.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 1.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Raises:
        TypeError: If learning_rate is not an int, a float or a LearningRateDecay.
        AssertionError: If end_lr is not greater than start_lr.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid

        learning_rate = 0.1
        warmup_steps = 50
        start_lr = 0.
        end_lr = 0.1

        with fluid.dygraph.guard():
            lr_decay = fluid.dygraph.LinearLrWarmup( learning_rate, warmup_steps, start_lr, end_lr)


    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 begin=1,
                 step=1,
                 dtype='float32'):
        super(LinearLrWarmup, self).__init__(begin, step, dtype)
        if not isinstance(learning_rate, (int, float, LearningRateDecay)):
            # Report the offending type, as the message promises.
            raise TypeError(
                "the type of learning_rate should be [int, float or LearningRateDecay], the current type is {}".
                format(type(learning_rate)))
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        # Kept so that step() can start the ramp at start_lr instead of 0.
        self.start_lr = start_lr
        # Left as an assert so callers keep seeing AssertionError.
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        # Learning rate increment per step during warm-up.
        self.lr_ratio_before_warmup = (
            float(end_lr) - float(start_lr)) / float(warmup_steps)

    def step(self):
        """
        Return the learning rate for the current step.

        During warm-up this is ``start_lr + (end_lr - start_lr) * step_num /
        warmup_steps``; afterwards it is ``learning_rate`` (evaluated if it is
        a scheduler).
        """
        base_lr = self.learning_rate
        if isinstance(self.learning_rate, LearningRateDecay):
            # The wrapped scheduler is evaluated on every call, warm-up
            # included, so its own step counter keeps advancing.
            base_lr = base_lr()

        if self.step_num < self.warmup_steps:
            return self.lr_ratio_before_warmup * self.step_num + self.start_lr
        return base_lr