#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from __future__ import print_function
16
import logging
17

18
from . import framework
19
from .framework import in_dygraph_mode, _varbase_creator
C
chengduoZH 已提交
20
from . import core
21

Y
yuyang18 已提交
22
# Public API of this module: the short aliases and the full class names.
__all__ = [
    'L1Decay',
    'L2Decay',
    'L1DecayRegularizer',
    'L2DecayRegularizer',
]
23 24


25
def _create_regularization_of_grad(param, grad, regularization=None):
26 27 28 29 30
    """ Create and add backward regularization Operators

    Function helper of append_regularization_ops.
    """
    # If no gradient or no regularization is specified,  then we don't need to do anything
31 32 33
    if grad is None or ((not hasattr(param, 'regularizer') or (
            hasattr(param, 'regularizer') and param.regularizer is None)) and
                        regularization is None):
34 35
        return grad
    regularization_term = None
36
    if hasattr(param, 'regularizer') and param.regularizer is not None:
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
        # Add variable for regularization term in grad block
        regularization_term = param.regularizer(param, grad, grad.block)
    elif regularization is not None:
        regularization_term = regularization(param, grad, grad.block)

    assert regularization_term is not None

    new_grad = grad
    if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
        # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
        # the grad's type and name will be changed. But the gradient's name
        # is used in ParallelExecutor Reduce mode, so I add a flag for
        # the new_grad here.
        new_grad = grad.block.create_var(
            name=grad.name + core.kNewGradSuffix(),
            dtype=param.dtype,
            shape=param.shape,
            lod_level=param.lod_level,
            type=core.VarDesc.VarType.LOD_TENSOR)

    inputs = {"X": [grad, regularization_term]}
    outputs = {"Out": [new_grad]}
    if in_dygraph_mode():
60
        new_grad = core.ops.sum([grad, regularization_term])
61 62 63 64 65 66
    else:
        grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)

    return new_grad


67
def append_regularization_ops(parameters_and_grads, regularization=None):
    r"""Create and add backward regularization Operators

    Creates and adds backward regularization operators in the BlockDesc.
    This will add gradients of the regularizer function to the gradients
    of the parameters and return these modified gradients. This is the
    same as implementing weight decay in optimizers for regularization.

    Args:
        parameters_and_grads: A list of (parameters, gradients) pairs
                              that need to be regularized.
        regularization: A global regularizer. It is applied only to the
                        parameters that do not set their own regularizer.

    Returns:
        list[(Variable, Variable)]: list of (parameters, gradients) \
        pair with the regularized gradient
    """
    params_and_grads = []

    if in_dygraph_mode():
        for param, grad in parameters_and_grads:
            new_grad = _create_regularization_of_grad(param, grad,
                                                      regularization)
            params_and_grads.append((param, new_grad))
        return params_and_grads

    # Only log the "optimizer regularizer is shadowed" notice once.
    shadowing_logged = False
    with framework.name_scope('regularization'):
        for param, grad in parameters_and_grads:
            # getattr keeps this consistent with
            # _create_regularization_of_grad, which tolerates parameters
            # without a ``regularizer`` attribute.
            if (not shadowing_logged and regularization is not None and
                    getattr(param, 'regularizer', None) is not None):
                shadowing_logged = True
                logging.info(
                    "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
                    "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!",
                    regularization)
            with param.block.program._optimized_guard([param, grad]):
                new_grad = _create_regularization_of_grad(param, grad,
                                                          regularization)
                params_and_grads.append((param, new_grad))
    return params_and_grads


class WeightDecayRegularizer(object):
    """Abstract interface shared by all weight-decay regularizers.

    Rather than adding a penalty to the loss, a weight-decay regularizer acts
    only in the backward pass. It appends operations that compute the
    gradient of the regularization function, and that term is added to the
    parameter's gradient. Do not instantiate this class directly; use one of
    its concrete subclasses.
    """

    def __init__(self):
        # The base class holds no state.
        pass

    def __call__(self, param, grad, block):
        """Append the weight-decay operations for ``param`` to ``block``.

        Concrete subclasses must override this.
        """
        raise NotImplementedError

    def __str__(self):
        """Return a human-readable description; subclasses must override."""
        raise NotImplementedError

135 136

class L2DecayRegularizer(WeightDecayRegularizer):
    r"""
    Implement the L2 Weight Decay Regularization, which helps to prevent the model over-fitting.

    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_fluid_optimizer_SGDOptimizer` ).
    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
    higher priority than ``optimizer`` .

    In the implementation, the formula of L2 Weight Decay Regularization is as follows:

    .. math::

        L2WeightDecay = reg\_coeff * parameter

    Args:
        regularization_coeff(float, optional): regularization coeff. Default:0.0

    Examples:
        .. code-block:: python

            # Example1: set Regularizer in optimizer
            import paddle.fluid as fluid

            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name='image', shape=[3, 28, 28], dtype='float32')
                label = fluid.layers.data(name='label', shape=[1], dtype='int64')
                hidden = fluid.layers.fc(input=data, size=128, act='relu')
                prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
                loss = fluid.layers.cross_entropy(input=prediction, label=label)
                avg_loss = fluid.layers.mean(loss)
            optimizer = fluid.optimizer.Adagrad(
                learning_rate=1e-4,
                regularization=fluid.regularizer.L2Decay(
                    regularization_coeff=0.1))
            optimizer.minimize(avg_loss)


            # Example2: set Regularizer both in ParamAttr and optimizer
            import paddle.fluid as fluid

            l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
            l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
            x = fluid.layers.uniform_random([3,4])

            # set L1 regularization in fluid.ParamAttr
            w_param = fluid.ParamAttr(regularizer=l1)
            hidden1 = fluid.layers.fc(x, 8, param_attr=w_param)  # fc_0.w_0(L1), fc_0.b_0
            hidden2 = fluid.layers.fc(hidden1, 16, param_attr=w_param)   # fc_1.w_0(L1), fc_1.b_0
            predict = fluid.layers.fc(hidden2, 32)    # fc_2.w_0, fc_2.b_0
            avg_loss = fluid.layers.mean(predict)

            # set L2 regularization in optimizer
            optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
            optimizer.minimize(avg_loss)

            # fc_0.w_0 and fc_1.w_0 keep their L1 regularizer; the optimizer's
            # L2Decay is applied only to the other parameters, and the
            # following message is logged (INFO level):
            # If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already.
            # The Regularization[L2Decay, regularization_coeff=0.100000] in Optimizer will not take effect, and it will only be applied to other Parameters!

    """

    def __init__(self, regularization_coeff=0.0):
        assert regularization_coeff is not None
        super(L2DecayRegularizer, self).__init__()
        self._regularization_coeff = regularization_coeff

    def __call__(self, param, grad, block):
        """Add L2 weight decay ops to network

        Adds L2 weight decay ops.
        L2WeightDecay = reg_coeff * parameter

        Args:
            param: parameter variable for which regularization is applied
            grad: gradient of ``param``; not used by L2 decay, accepted to
                match the WeightDecayRegularizer interface
            block: block in which variable is to be created

        Returns:
            new variable for weight decay
        """
        assert isinstance(param, framework.Variable)
        assert isinstance(block, framework.Block)

        if framework.in_dygraph_mode():
            # Eager mode: compute reg_coeff * param immediately.
            return core.ops.scale(param, "scale", self._regularization_coeff)

        decay = block.create_var(
            dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

        # Append Op to calculate decay
        block.append_op(
            type='scale',
            inputs={"X": param},
            outputs={"Out": decay},
            attrs={"scale": self._regularization_coeff})

        return decay

    def __str__(self):
        return "L2Decay, regularization_coeff=%f" % self._regularization_coeff

242 243

class L1DecayRegularizer(WeightDecayRegularizer):
    r"""
    Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.

    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_fluid_optimizer_SGDOptimizer` ).
    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
    higher priority than ``optimizer`` .

    In the implementation, the formula of L1 Weight Decay Regularization is as follows:

    .. math::

        L1WeightDecay = reg\_coeff * sign(parameter)

    Args:
        regularization_coeff(float, optional): regularization coeff. Default:0.0.

    Examples:
        .. code-block:: python

            # Example1: set Regularizer in optimizer
            import paddle.fluid as fluid

            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name='image', shape=[3, 28, 28], dtype='float32')
                label = fluid.layers.data(name='label', shape=[1], dtype='int64')
                hidden = fluid.layers.fc(input=data, size=128, act='relu')
                prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
                loss = fluid.layers.cross_entropy(input=prediction, label=label)
                avg_loss = fluid.layers.mean(loss)
            optimizer = fluid.optimizer.Adagrad(
                learning_rate=1e-4,
                regularization=fluid.regularizer.L1DecayRegularizer(
                    regularization_coeff=0.1))
            optimizer.minimize(avg_loss)


            # Example2: set Regularizer both in ParamAttr and optimizer
            import paddle.fluid as fluid

            l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
            l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
            x = fluid.layers.uniform_random([3,4])

            # set L1 regularization in fluid.ParamAttr
            w_param = fluid.ParamAttr(regularizer=l1)
            hidden1 = fluid.layers.fc(x, 8, param_attr=w_param)  # fc_0.w_0(L1), fc_0.b_0
            hidden2 = fluid.layers.fc(hidden1, 16, param_attr=w_param)  # fc_1.w_0(L1), fc_1.b_0
            predict = fluid.layers.fc(hidden2, 32)   # fc_2.w_0, fc_2.b_0
            avg_loss = fluid.layers.mean(predict)

            # set L2 regularization in optimizer
            optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
            optimizer.minimize(avg_loss)

            # fc_0.w_0 and fc_1.w_0 keep their L1 regularizer; the optimizer's
            # L2Decay is applied only to the other parameters, and the
            # following message is logged (INFO level):
            # If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already.
            # The Regularization[L2Decay, regularization_coeff=0.100000] in Optimizer will not take effect, and it will only be applied to other Parameters!

    """

    def __init__(self, regularization_coeff=0.0):
        assert regularization_coeff is not None
        super(L1DecayRegularizer, self).__init__()
        self._regularization_coeff = regularization_coeff

    def __call__(self, param, grad, block):
        """Add L1 weight decay ops to network

        Adds L1 weight decay ops.
        L1WeightDecay = reg_coeff * sign(parameter)

        Args:
            param: parameter variable for which regularization is applied
            grad: gradient of ``param``; not used by L1 decay, accepted to
                match the WeightDecayRegularizer interface
            block: block in which variable is to be created

        Returns:
            new variable for weight decay
        """
        assert isinstance(param, framework.Variable)
        assert isinstance(block, framework.Block)

        # lod_level is only forwarded to create_var in static-graph mode;
        # the dygraph branch creates the vars from dtype and shape alone.
        if framework.in_dygraph_mode():
            sign = block.create_var(dtype=param.dtype, shape=param.shape)
            decay = block.create_var(dtype=param.dtype, shape=param.shape)
        else:
            sign = block.create_var(
                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
            decay = block.create_var(
                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

        # Append sign op
        block.append_op(type='sign', inputs={"X": param}, outputs={"Out": sign})

        # Append scale op to the output of sign op
        block.append_op(
            type='scale',
            inputs={"X": sign},
            outputs={"Out": decay},
            attrs={"scale": self._regularization_coeff})

        return decay

    def __str__(self):
        return "L1Decay, regularization_coeff=%f" % self._regularization_coeff

352 353 354 355 356 357 358

# Shortened aliases. Users reach these classes through the package name
# (``fluid.regularizer``), so a ``Regularizer`` suffix on the class name would
# be redundant. Sample code:
#
# import paddle.fluid as fluid
#
# hidden = fluid.layers.fc(...,
#                          param_attr=fluid.ParamAttr(
#                              regularizer=fluid.regularizer.L2Decay(
#                                  regularization_coeff=0.1)))
L1Decay = L1DecayRegularizer
L2Decay = L2DecayRegularizer