#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from __future__ import print_function
16
import logging
17

18
from . import framework
19
from .framework import in_dygraph_mode, _varbase_creator
C
chengduoZH 已提交
20
from . import core
21

# Public API of this module. `L1Decay` / `L2Decay` are short aliases of the
# `L1DecayRegularizer` / `L2DecayRegularizer` classes.
__all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']


def _create_regularization_of_grad(param, grad, regularization=None):
    """Create and add the backward regularization operator for one gradient.

    Function helper of ``append_regularization_ops``. The regularizer attached
    to ``param`` (``param.regularizer``) takes priority over the global
    ``regularization``; if neither is set, ``grad`` is returned unchanged.

    Args:
        param: The parameter being regularized. A parameter without a
            ``regularizer`` attribute is treated as having no regularizer.
        grad: The gradient of ``param``, or None.
        regularization: Optional global regularizer, used only when ``param``
            has no regularizer of its own.

    Returns:
        The regularized gradient (``grad + regularization_term``), or ``grad``
        itself when there is no gradient or no regularizer.
    """
    param_regularizer = getattr(param, 'regularizer', None)

    # If no gradient or no regularization is specified, then we don't need
    # to do anything.
    if grad is None or (param_regularizer is None and regularization is None):
        return grad

    # The per-parameter regularizer overrides the global one.
    regularizer = (param_regularizer
                   if param_regularizer is not None else regularization)
    # Add variable for regularization term in grad block
    regularization_term = regularizer(param, grad, grad.block)
    assert regularization_term is not None

    if in_dygraph_mode():
        # Eager mode: sum directly. The static-graph output variable created
        # below for SELECTED_ROWS gradients is not needed here.
        return core.ops.sum([grad, regularization_term])

    new_grad = grad
    if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
        # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
        # the grad's type and name will be changed. But the gradient's name
        # is used in ParallelExecutor Reduce mode, so I add a flag for
        # the new_grad here.
        new_grad = grad.block.create_var(
            name=grad.name + core.kNewGradSuffix(),
            dtype=param.dtype,
            shape=param.shape,
            lod_level=param.lod_level,
            type=core.VarDesc.VarType.LOD_TENSOR)

    inputs = {"X": [grad, regularization_term]}
    outputs = {"Out": [new_grad]}
    grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)

    return new_grad


def append_regularization_ops(parameters_and_grads, regularization=None):
    """Create and add backward regularization Operators

    Creates and adds backward regularization operators in the BlockDesc.
    This will add gradients of the regularizer function to the gradients
    of the parameters and return these modified gradients. This is the
    same as implementing weight decay in optimizers for regularization.

    Args:
        parameters_and_grads: A list of (parameters, gradients) pairs
                              that need to be regularized.
        regularization: A global regularizer. It is applied only to
                        parameters that do not have a regularizer of their
                        own; a parameter's own regularizer (e.g. set via
                        ``ParamAttr``) takes priority.

    Returns:
        list[(Variable, Variable)]: list of (parameters, gradients) \
        pair with the regularized gradient

    Raises:
        AssertionError: if a regularizer returns None as its
            regularization term.
    """
    params_and_grads = []

    if in_dygraph_mode():
        for param, grad in parameters_and_grads:
            new_grad = _create_regularization_of_grad(param, grad,
                                                      regularization)
            params_and_grads.append((param, new_grad))
        return params_and_grads

    # The "parameter regularizer overrides optimizer regularizer" notice is
    # logged at most once per call.
    override_logged = False
    with framework.name_scope('regularization'):
        for param, grad in parameters_and_grads:
            if (not override_logged and
                    getattr(param, 'regularizer', None) is not None and
                    regularization is not None):
                override_logged = True
                # Lazy %-args: the regularizer is only stringified when the
                # record is actually emitted.
                logging.info(
                    "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
                    "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!",
                    regularization)
            with param.block.program._optimized_guard([param, grad]):
                new_grad = _create_regularization_of_grad(param, grad,
                                                          regularization)
                params_and_grads.append((param, new_grad))
    return params_and_grads


class WeightDecayRegularizer(object):
    """Base class for weight decay regularizers

    Defines the common interface of weight-decay regularizers.
    Weight-decay regularizers are added only during the backward
    pass for faster regularization. They add operations to the network
    that correspond to gradient of the regularization function.
    Users should not use this class directly, but need to use one
    of its implementations.
    """

    def __init__(self):
        pass

    def __call__(self, param, grad, block):
        """Add corresponding weight decay operations to the network.

        Subclasses must override this and return the regularization term
        for ``param``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def __str__(self):
        """Debug string.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


class L2DecayRegularizer(WeightDecayRegularizer):
    r"""
    Implement the L2 Weight Decay Regularization, which helps to prevent the model over-fitting.

    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_fluid_optimizer_SGDOptimizer` ).
    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
    higher priority than ``optimizer`` .

    In the implementation, the formula of L2 Weight Decay Regularization is as follows:

    .. math::

        L2WeightDecay = reg\_coeff * parameter

    Args:
        regularization_coeff(float, optional): the regularization coefficient. Default: 0.0.

    Examples:
        .. code-block:: python

            # Example1: set Regularizer in optimizer
            import paddle.fluid as fluid

            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name='image', shape=[3, 28, 28], dtype='float32')
                label = fluid.layers.data(name='label', shape=[1], dtype='int64')
                hidden = fluid.layers.fc(input=data, size=128, act='relu')
                prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
                loss = fluid.layers.cross_entropy(input=prediction, label=label)
                avg_loss = fluid.layers.mean(loss)
            optimizer = fluid.optimizer.Adagrad(
                learning_rate=1e-4,
                regularization=fluid.regularizer.L2Decay(
                    regularization_coeff=0.1))
            optimizer.minimize(avg_loss)


            # Example2: set Regularizer both in ParamAttr and optimizer
            import paddle.fluid as fluid

            l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
            l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
            x = fluid.layers.uniform_random([3,4])

            # set L1 regularization in fluid.ParamAttr
            w_param = fluid.ParamAttr(regularizer=l1)
            hidden1 = fluid.layers.fc(x, 8, param_attr=w_param)  # fc_0.w_0(L1), fc_0.b_0
            hidden2 = fluid.layers.fc(hidden1, 16, param_attr=w_param)   # fc_1.w_0(L1), fc_1.b_0
            predict = fluid.layers.fc(hidden2, 32)    # fc_2.w_0, fc_2.b_0
            avg_loss = fluid.layers.mean(predict)

            # set L2 regularization in optimizer
            optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
            optimizer.minimize(avg_loss)

            # fc_0.w_0 and fc_1.w_0 keep their L1 regularizer (ParamAttr has higher
            # priority); the optimizer's L2 regularizer is applied only to the other
            # parameters, and an informational message about this is logged.

    """

    def __init__(self, regularization_coeff=0.0):
        assert regularization_coeff is not None
        super(L2DecayRegularizer, self).__init__()
        self._regularization_coeff = regularization_coeff

    def __call__(self, param, grad, block):
        """Add L2 weight decay ops to network

        Adds L2 weight decay ops.
        L2WeightDecay = reg_coeff * parameter

        Args:
            param: parameter variable for which regularization is applied
            grad: gradient of ``param``; not used by L2 decay, accepted for
                interface compatibility with other regularizers
            block: block in which variable is to be created

        Returns:
            new variable for weight decay
        """
        assert isinstance(param, framework.Parameter)
        assert isinstance(block, framework.Block)

        if framework.in_dygraph_mode():
            # Eager mode: compute reg_coeff * param immediately.
            return core.ops.scale(param, "scale", self._regularization_coeff)

        decay = block.create_var(
            dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

        # Append Op to calculate decay
        block.append_op(
            type='scale',
            inputs={"X": param},
            outputs={"Out": decay},
            attrs={"scale": self._regularization_coeff})

        return decay

    def __str__(self):
        return "L2Decay, regularization_coeff=%f" % self._regularization_coeff


class L1DecayRegularizer(WeightDecayRegularizer):
    r"""
    Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.

    It can be set in :ref:`api_fluid_ParamAttr` or ``optimizer`` (such as :ref:`api_fluid_optimizer_SGDOptimizer` ).
    When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
    ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
    higher priority than ``optimizer`` .

    In the implementation, the formula of L1 Weight Decay Regularization is as follows:

    .. math::

        L1WeightDecay = reg\_coeff * sign(parameter)

    Args:
        regularization_coeff(float, optional): the regularization coefficient. Default: 0.0.

    Examples:
        .. code-block:: python

            # Example1: set Regularizer in optimizer
            import paddle.fluid as fluid

            main_prog = fluid.Program()
            startup_prog = fluid.Program()
            with fluid.program_guard(main_prog, startup_prog):
                data = fluid.layers.data(name='image', shape=[3, 28, 28], dtype='float32')
                label = fluid.layers.data(name='label', shape=[1], dtype='int64')
                hidden = fluid.layers.fc(input=data, size=128, act='relu')
                prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
                loss = fluid.layers.cross_entropy(input=prediction, label=label)
                avg_loss = fluid.layers.mean(loss)
            optimizer = fluid.optimizer.Adagrad(
                learning_rate=1e-4,
                regularization=fluid.regularizer.L1DecayRegularizer(
                    regularization_coeff=0.1))
            optimizer.minimize(avg_loss)


            # Example2: set Regularizer both in ParamAttr and optimizer
            import paddle.fluid as fluid

            l1 = fluid.regularizer.L1Decay(regularization_coeff=0.1)
            l2 = fluid.regularizer.L2Decay(regularization_coeff=0.1)
            x = fluid.layers.uniform_random([3,4])

            # set L1 regularization in fluid.ParamAttr
            w_param = fluid.ParamAttr(regularizer=l1)
            hidden1 = fluid.layers.fc(x, 8, param_attr=w_param)  # fc_0.w_0(L1), fc_0.b_0
            hidden2 = fluid.layers.fc(hidden1, 16, param_attr=w_param)  # fc_1.w_0(L1), fc_1.b_0
            predict = fluid.layers.fc(hidden2, 32)   # fc_2.w_0, fc_2.b_0
            avg_loss = fluid.layers.mean(predict)

            # set L2 regularization in optimizer
            optimizer = fluid.optimizer.SGD(learning_rate=1e-4, regularization=l2)
            optimizer.minimize(avg_loss)

            # fc_0.w_0 and fc_1.w_0 keep their L1 regularizer (ParamAttr has higher
            # priority); the optimizer's L2 regularizer is applied only to the other
            # parameters, and an informational message about this is logged.

    """

    def __init__(self, regularization_coeff=0.0):
        assert regularization_coeff is not None
        super(L1DecayRegularizer, self).__init__()
        self._regularization_coeff = regularization_coeff

    def __call__(self, param, grad, block):
        """Add L1 weight decay ops to network

        Adds L1 weight decay ops.
        L1WeightDecay = reg_coeff * sign(parameter)

        Args:
            param: parameter variable for which regularization is applied
            grad: gradient of ``param``; not used by L1 decay, accepted for
                interface compatibility with other regularizers
            block: block in which variable is to be created

        Returns:
            new variable for weight decay
        """
        assert isinstance(param, framework.Parameter)
        assert isinstance(block, framework.Block)

        if framework.in_dygraph_mode():
            # In dygraph mode the variables are created without lod_level.
            sign = block.create_var(dtype=param.dtype, shape=param.shape)
            decay = block.create_var(dtype=param.dtype, shape=param.shape)
        else:
            sign = block.create_var(
                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
            decay = block.create_var(
                dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)

        # Append sign op: sign = sign(param)
        block.append_op(
            type='sign', inputs={"X": param}, outputs={"Out": sign})

        # Append scale op: decay = reg_coeff * sign. Writing into a separate
        # variable (instead of scaling in place) keeps each variable assigned
        # exactly once.
        block.append_op(
            type='scale',
            inputs={"X": sign},
            outputs={"Out": decay},
            attrs={"scale": self._regularization_coeff})

        return decay

    def __str__(self):
        return "L1Decay, regularization_coeff=%f" % self._regularization_coeff


# We shorten the class names, since users will use the regularizers with the
# package name. The sample code:
#
# import paddle.fluid as fluid
#
# hidden = fluid.layers.fc(...,
#                          param_attr=fluid.ParamAttr(
#                              regularizer=fluid.regularizer.L2Decay(
#                                  regularization_coeff=1e-4)))
#
# There is no need to add `Regularizer` as the class suffix.
L1Decay = L1DecayRegularizer
L2Decay = L2DecayRegularizer