from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import logging

from paddle import fluid

import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.layers.ops import cos

from ppdet.core.workspace import register, serializable

__all__ = ['Optimize']

logger = logging.getLogger(__name__)


@serializable
@register
class PiecewiseDecay(object):
    """
    Multi step learning rate decay

    Args:
        gamma (float | list): learning rate decay factor(s); a single float is
            expanded to [gamma, gamma / 10, gamma / 100, ...], one entry per
            milestone
        milestones (list): epochs at which to decay the learning rate
    """

    def __init__(self, gamma=[0.1, 0.01], milestones=[8, 11]):
        super(PiecewiseDecay, self).__init__()
        # expand a scalar gamma into one decay factor per milestone
        if not isinstance(gamma, list):
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        if boundary is not None:
            # convert milestone epochs into absolute step boundaries
            boundary.extend([int(step_per_epoch) * i for i in self.milestones])

        if value is not None:
            for i in self.gamma:
                value.append(base_lr * i)

        return fluid.dygraph.PiecewiseDecay(boundary, value, begin=0, step=1)
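
# Usage sketch (illustrative only, not part of the original module): when
# driven through BaseLR below, PiecewiseDecay receives the warmup
# boundary/value lists and extends them with the milestone epochs converted
# to absolute steps. A minimal standalone call, assuming 500 steps per epoch:
#
#   decay = PiecewiseDecay(gamma=[0.1, 0.01], milestones=[8, 11])
#   lr = decay(base_lr=0.01, boundary=[], value=[0.01], step_per_epoch=500)
#   # boundary -> [4000, 5500], value -> [0.01, 0.001, 0.0001]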


@serializable
@register
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
    """

    def __init__(self, steps=500, start_factor=1. / 3):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor

    def __call__(self, base_lr):
        boundary = []
        value = []
        for i in range(self.steps):
            alpha = i / self.steps
            factor = self.start_factor * (1 - alpha) + alpha
            lr = base_lr * factor
            value.append(lr)
            if i > 0:
                boundary.append(i)
        boundary.append(self.steps)
        value.append(base_lr)
        return boundary, value
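
# Usage sketch (illustrative only): LinearWarmup builds the piecewise
# boundary/value lists that ramp the learning rate from
# start_factor * base_lr up to base_lr over `steps` iterations, e.g.
#
#   warmup = LinearWarmup(steps=500, start_factor=1. / 3)
#   boundary, value = warmup(base_lr=0.01)
#   # value[0] ~= 0.0033, value[-1] == 0.01, len(value) == len(boundary) + 1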


@serializable
@register
class BaseLR(object):
    """
    Learning rate configuration

    Args:
        base_lr (float): base learning rate
        decay (object): learning rate decay scheduler, e.g. `PiecewiseDecay`
        warmup (object): learning rate warmup scheduler, e.g. `LinearWarmup`
    """
    __inject__ = ['decay', 'warmup']

    def __init__(self, base_lr=0.01, decay=None, warmup=None):
        super(BaseLR, self).__init__()
        self.base_lr = base_lr
        self.decay = decay
        self.warmup = warmup

    def __call__(self, step_per_epoch):
        # warmup
        boundary, value = self.warmup(self.base_lr)
        # decay
        decay_lr = self.decay(self.base_lr, boundary, value, step_per_epoch)
        return decay_lr
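
# Usage sketch (illustrative only): BaseLR chains warmup and decay into one
# dygraph learning rate schedule. `decay` and `warmup` are normally injected
# from the config via `__inject__`, but they can also be passed directly:
#
#   lr = BaseLR(base_lr=0.01,
#               decay=PiecewiseDecay(gamma=[0.1, 0.01], milestones=[8, 11]),
#               warmup=LinearWarmup(steps=500, start_factor=1. / 3))
#   lr_schedule = lr(step_per_epoch=500)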


@register
class Optimize(object):
    """
    Build the optimizer handle

    Args:
        learning_rate (object): a `BaseLR` instance
        optimizer (dict): optimizer config, e.g. {'name': 'Momentum', 'momentum': 0.9}
        regularizer (dict): regularizer config, e.g. {'name': 'L2', 'factor': 0.0001}
        clip_grad_by_norm (float): if set, clip gradients by this global norm
    """
    __category__ = 'optim'
    __inject__ = ['learning_rate']

    def __init__(self,
                 learning_rate,
                 optimizer={'name': 'Momentum',
                            'momentum': 0.9},
                 regularizer={'name': 'L2',
                              'factor': 0.0001},
                 clip_grad_by_norm=None):
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.regularizer = regularizer
        self.clip_grad_by_norm = clip_grad_by_norm

    def __call__(self, params=None, step_per_epoch=1):

        if self.regularizer:
            # resolve the regularizer class by name, e.g. 'L2' -> fluid.regularizer.L2Decay
            reg_type = self.regularizer['name'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

        if self.clip_grad_by_norm is not None:
            fluid.clip.set_gradient_clip(
                clip=fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=self.clip_grad_by_norm))

        # resolve the optimizer class by name, e.g. fluid.optimizer.Momentum,
        # and pass the remaining config entries through as keyword arguments
        optim_args = self.optimizer.copy()
        optim_type = optim_args['name']
        del optim_args['name']
        op = getattr(optimizer, optim_type)

        return op(learning_rate=self.learning_rate(step_per_epoch),
                  parameter_list=params,
                  regularization=regularization,
                  **optim_args)
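
# Usage sketch (illustrative only, assumes a dygraph `model` is defined
# elsewhere): Optimize resolves the regularizer and optimizer classes by name
# from paddle.fluid and wires in the schedule built by BaseLR, e.g. with `lr`
# being the BaseLR instance from the sketch above:
#
#   optim = Optimize(learning_rate=lr,
#                    optimizer={'name': 'Momentum', 'momentum': 0.9},
#                    regularizer={'name': 'L2', 'factor': 0.0001})
#   momentum_op = optim(params=model.parameters(), step_per_epoch=500)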