#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import warnings
import sys

from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer
from paddle.fluid.data_feeder import check_type

__all__ = [
    'ParamAttr',
    'WeightNormParamAttr',
]


class ParamAttr(object):
    """

    Note:
        ``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0. 
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
        There are three clipping strategies: :ref:`api_paddle_nn_ClipGradByGlobalNorm` , 
        :ref:`api_paddle_nn_ClipGradByNorm` , :ref:`api_paddle_nn_ClipGradByValue` .

    Create an object to represent the attributes of a parameter. The attributes are:
    name, initializer, learning rate, regularizer, trainable, gradient clip,
    and model average.

    Parameters:
        name (str, optional): The parameter's name. Default None, meaning that the name
                would be created automatically.
        initializer (Initializer, optional): The method to initialize this parameter. Default
                None, meaning that the weight parameter is initialized by Xavier initializer,
                and the bias parameter is initialized by 0.
        learning_rate (float, optional): The parameter's learning rate. The effective learning
                rate during optimization is the global learning rate multiplied by the parameter's
                learning rate and by the factor of the learning rate scheduler. Default 1.0.
        regularizer (WeightDecayRegularizer, optional): Regularization strategy. There are two
                methods: :ref:`api_paddle_regularizer_L1Decay` , :ref:`api_paddle_regularizer_L2Decay` .
                If regularizer is also set in ``optimizer`` (such as :ref:`api_paddle_optimizer_SGD` ),
                the regularizer setting in the optimizer will be ignored. Default None, meaning
                there is no regularization.
        trainable (bool, optional): Whether this parameter is trainable. Default True.
        do_model_average (bool, optional): Whether this parameter should do model average
                when model average is enabled. Only used in ExponentialMovingAverage. Default True.
        need_clip (bool, optional): Whether the parameter gradient needs to be clipped by the optimizer. Default is True.

    Returns:
       ParamAttr Object.

    Examples:
    
        .. code-block:: python

            import paddle

            weight_attr = paddle.ParamAttr(name="weight",
                                           learning_rate=0.5,
                                           regularizer=paddle.regularizer.L2Decay(1.0),
                                           trainable=True)
            print(weight_attr.name) # "weight"
            paddle.nn.Linear(3, 4, weight_attr=weight_attr)
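            # A minimal sketch of ``need_clip`` (assuming the Paddle 2.x clipping API):
            # a parameter whose attribute sets ``need_clip=False`` is skipped by the
            # gradient clipping strategy configured on the optimizer.
            bias_attr = paddle.ParamAttr(name="bias", need_clip=False)
            paddle.nn.Linear(3, 4, weight_attr=weight_attr, bias_attr=bias_attr)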
    """

    def __init__(self,
                 name=None,
                 initializer=None,
                 learning_rate=1.0,
                 regularizer=None,
                 trainable=True,
                 do_model_average=True,
                 need_clip=True):

        if sys.version_info.major == 2:
            check_type(name, "name", (str, type(None), unicode), "ParamAttr")
        else:
            check_type(name, "name", (str, type(None)), "ParamAttr")
        check_type(learning_rate, "learning_rate", (float, int), "ParamAttr")
        check_type(trainable, "trainable", (bool), "ParamAttr")
        check_type(do_model_average, "do_model_average", (bool), "ParamAttr")
        check_type(need_clip, "need_clip", (bool), "ParamAttr")
        check_type(initializer, "initializer", (Initializer, type(None)),
                   "ParamAttr")
        check_type(regularizer, "regularizer",
                   (WeightDecayRegularizer, type(None)), "ParamAttr")

        self.name = name
        if self.name == "":
            raise ValueError("name of ParamAttr can not be empty str")

        self.initializer = initializer
        self.learning_rate = learning_rate
        self.regularizer = regularizer
        self.trainable = trainable
        self.do_model_average = do_model_average
        self.need_clip = need_clip

    def _set_default_initializer(self, initializer):
        """
        Set the default initializer. The initializer should be Constant,
        Uniform, Normal, Xavier or MSRA.

        Args:
            initializer(Initializer): the initializer to set.

        Returns:
            None
        """
        if initializer is None:
            if self.initializer is None:
                raise ValueError("ParamAttr.initializer is not set")
            return

        if self.initializer is not None:
            return

        self.initializer = initializer

    def _set_default_param_initializer(self):
        """
        Set the default initializer for the parameter with Xavier.

        Args:
            None.

        Returns:
            None.
        """
        self._set_default_initializer(Xavier())

    def _set_default_bias_initializer(self):
        """
        Set the default initializer for the bias with Constant(0.0).

        Args:
            None.

        Returns:
            None.
        """
        self._set_default_initializer(Constant(0.0))

    @staticmethod
    def _to_attr(arg):
        """
        Create ParamAttr[s].

        Args:
            arg: Arguments to initialize ParamAttr[s]. arg's type can be
                str, Initializer, WeightDecayRegularizer, bool, ParamAttr, None,
                or a list/tuple of the above types.

        Returns:
            ParamAttr[s]: ParamAttr[s] initialized with arg.

        Raises:
            TypeError: if ``arg`` can not be used to initialize a ParamAttr.
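        Example (an illustrative sketch of the conversions implemented below)::

            ParamAttr._to_attr("fc.w_0")      # -> ParamAttr(name="fc.w_0")
            ParamAttr._to_attr(False)         # -> False, i.e. no parameter attribute
            ParamAttr._to_attr([None, "w"])   # -> [ParamAttr(), ParamAttr(name="w")]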
        """
        if arg is None:
            return ParamAttr()
        elif isinstance(arg, list) or isinstance(arg, tuple):
            return [ParamAttr._to_attr(a) for a in arg]
        elif isinstance(arg, ParamAttr):
            return arg
        elif isinstance(arg, six.string_types):
            return ParamAttr(name=arg)
        elif isinstance(arg, Initializer):
            return ParamAttr(initializer=arg)
        elif isinstance(arg, WeightDecayRegularizer):
            return ParamAttr(regularizer=arg)
        elif isinstance(arg, bool):
            return ParamAttr._to_attr(None) if arg else False
        else:
            raise TypeError("{0} can not be cast to ParamAttr".format(type(arg)))

    def _to_kwargs(self, with_initializer=False):
        """
        Returns the attributes of this parameter.

        Args:
            with_initializer(bool): Whether to add initializer attr.

        Returns:
            Parameter attributes(map): The attributes of this parameter.
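        Example of the returned mapping (an illustrative sketch for a default
        ``ParamAttr(name='fc.w_0')``)::

            {'name': 'fc.w_0',
             'optimize_attr': {'learning_rate': 1.0},
             'regularizer': None,
             'trainable': True,
             'do_model_average': True,
             'need_clip': True}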
        """
        kwargs = {
            'name': self.name,
            'optimize_attr': {
                'learning_rate': self.learning_rate
            },
            'regularizer': self.regularizer,
            'trainable': self.trainable,
            'do_model_average': self.do_model_average,
            'need_clip': self.need_clip
        }
        if with_initializer:
            kwargs['initializer'] = self.initializer
        return kwargs


class WeightNormParamAttr(ParamAttr):
    r"""

    Note:
        Please use 'paddle.nn.utils.weight_norm' in dygraph mode.
	
    Note:
        ``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0. 
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
        There are three clipping strategies: :ref:`api_paddle_nn_ClipGradByGlobalNorm` , 
        :ref:`api_paddle_nn_ClipGradByNorm` , :ref:`api_paddle_nn_ClipGradByValue` .
	
    Parameter attribute for Weight Norm. Weight Norm is a reparameterization of the weight vectors
    in a neural network that decouples the magnitude of those weight vectors from
    their direction. Weight Norm has been implemented as discussed in this
    paper: `Weight Normalization: A Simple Reparameterization to Accelerate
    Training of Deep Neural Networks
    <https://arxiv.org/pdf/1602.07868.pdf>`_.

    Args:
        dim(int, optional): Dimension over which to compute the norm. Dim is a non-negative
            number which is less than the rank of the weight Tensor. For example, dim can
            be chosen from 0, 1, 2, 3 for convolution whose weight shape is [cout, cin, kh, kw]
            and rank is 4. Default None, meaning that all elements will be normalized.
        name(str, optional): The parameter's name. Default None, meaning that the name would
            be created automatically. Please refer to :ref:`api_guide_Name` for more details.
        initializer(Initializer, optional): The method to initialize this parameter, such as
            ``initializer = paddle.nn.initializer.Constant(1.0)``. Default None,
            meaning that the weight parameter is initialized by Xavier initializer, and
            the bias parameter is initialized by 0.
        learning_rate(float32, optional): The parameter's learning rate. The effective learning
            rate during optimization is :math:`global\_lr * parameter\_lr * scheduler\_factor`.
            Default 1.0.
        regularizer (WeightDecayRegularizer, optional): Regularization strategy. There are
            two methods: :ref:`api_paddle_regularizer_L1Decay` ,
            :ref:`api_paddle_regularizer_L2Decay`.
            If regularizer is also set in ``optimizer``
            (such as :ref:`api_paddle_optimizer_SGD` ), the regularizer setting in the
            optimizer will be ignored. Default None, meaning there is no regularization.
        trainable(bool, optional): Whether this parameter is trainable. Default True.
        do_model_average(bool, optional): Whether this parameter should do model average.
            Default False.
        need_clip (bool, optional): Whether the parameter gradient needs to be clipped by the optimizer. Default is True.

    Examples:
    
        .. code-block:: python
            
            import paddle

            paddle.enable_static()

            data = paddle.static.data(name="data", shape=[3, 32, 32], dtype="float32")

            fc = paddle.static.nn.fc(x=data,
                                     size=1000,
                                     weight_attr=paddle.static.WeightNormParamAttr(
                                         dim=None,
                                         name='weight_norm_param',
                                         initializer=paddle.nn.initializer.Constant(1.0),
                                         learning_rate=1.0,
                                         regularizer=paddle.regularizer.L2Decay(0.1),
                                         trainable=True,
                                         do_model_average=False,
                                         need_clip=True))
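            # In dygraph mode, the note above recommends paddle.nn.utils.weight_norm
            # instead of WeightNormParamAttr; a minimal sketch (assuming the Paddle 2.x
            # dygraph API):
            #
            #     paddle.disable_static()
            #     linear = paddle.nn.utils.weight_norm(paddle.nn.Linear(32, 64), dim=0)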

    """
    # List to record the parameters reparameterized by weight normalization.
    # If these parameters are treated as Variable rather than Parameter,
    # it can be used to discriminate these parameters and help to serialize
    # these parameters for inference.
    params_with_weight_norm = []

    def __init__(self,
                 dim=None,
                 name=None,
                 initializer=None,
                 learning_rate=1.0,
                 regularizer=None,
                 trainable=True,
                 do_model_average=False,
                 need_clip=True):
        super(WeightNormParamAttr, self).__init__(
            name=name,
            initializer=initializer,
            learning_rate=learning_rate,
            regularizer=regularizer,
            trainable=trainable,
            do_model_average=do_model_average,
            need_clip=need_clip)
        self.dim = dim