param_attr.py 12.0 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
F
fengjiayi 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
F
fengjiayi 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
F
update  
fengjiayi 已提交
14

15 16
from __future__ import print_function

17
import six
18
import warnings
19
import sys
20

21 22
from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer
23
from paddle.fluid.data_feeder import check_type
Y
Yu Yang 已提交
24

25 26 27 28
# Public API of this module.
__all__ = ['ParamAttr', 'WeightNormParamAttr']
Y
Yu Yang 已提交
29

Y
Yu Yang 已提交
30 31

class ParamAttr(object):
    """
    Create an object to represent the attributes of a parameter. The attributes are:
    name, initializer, learning rate, regularizer, trainable, gradient clip,
    and model average.

    Note:
        ``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
        There are three clipping strategies: :ref:`api_paddle_nn_ClipGradByGlobalNorm` ,
        :ref:`api_paddle_nn_ClipGradByNorm` , :ref:`api_paddle_nn_ClipGradByValue` .

    Parameters:
        name (str, optional): The parameter's name. Default None, meaning that the name
                would be created automatically. An empty string is not allowed.
        initializer (Initializer, optional): The method to initialize this parameter. Default
                None, meaning that the weight parameter is initialized by Xavier initializer,
                and the bias parameter is initialized by 0.
        learning_rate (float|int): The parameter's learning rate. The learning rate used
                during optimization is the global learning rate times the parameter's learning
                rate times the factor of the learning rate scheduler. Default 1.0.
        regularizer (WeightDecayRegularizer, optional): Regularization strategy. There are two methods:
                :ref:`api_paddle_regularizer_L1Decay` , :ref:`api_paddle_regularizer_L2Decay` . If
                regularizer is also set in ``optimizer`` (such as :ref:`api_paddle_optimizer_SGD` ),
                that regularizer setting in optimizer will be ignored. Default None, meaning there is
                no regularization.
        trainable (bool): Whether this parameter is trainable. Default True.
        do_model_average (bool): Whether this parameter should do model average
                when model average is enabled. Only used in ExponentialMovingAverage. Default True.
        need_clip (bool): Whether the parameter gradient needs to be clipped in optimizer. Default is True.

    Raises:
        ValueError: If ``name`` is an empty string.
        Arguments of an unexpected type are rejected by ``check_type``.

    Examples:
        .. code-block:: python

            import paddle

            weight_attr = paddle.ParamAttr(name="weight",
                                           learning_rate=0.5,
                                           regularizer=paddle.regularizer.L2Decay(1.0),
                                           trainable=True)
            print(weight_attr.name) # "weight"
            paddle.nn.Linear(3, 4, weight_attr=weight_attr)
    """

    def __init__(self,
                 name=None,
                 initializer=None,
                 learning_rate=1.0,
                 regularizer=None,
                 trainable=True,
                 do_model_average=True,
                 need_clip=True):

        # Python 2 has a separate ``unicode`` string type that must also be
        # accepted as a name; the name only exists on Python 2.
        if sys.version_info.major == 2:
            check_type(name, "name", (str, type(None), unicode), "ParamAttr")
        else:
            check_type(name, "name", (str, type(None)), "ParamAttr")
        check_type(learning_rate, "learning_rate", (float, int), "ParamAttr")
        check_type(trainable, "trainable", (bool), "ParamAttr")
        check_type(do_model_average, "do_model_average", (bool), "ParamAttr")
        check_type(need_clip, "need_clip", (bool), "ParamAttr")
        check_type(initializer, "initializer", (Initializer, type(None)),
                   "ParamAttr")
        check_type(regularizer, "regularizer",
                   (WeightDecayRegularizer, type(None)), "ParamAttr")

        self.name = name
        # None means "generate a name"; an explicit empty string is an error.
        if self.name == "":
            raise ValueError("name of ParamAttr can not be empty str")

        self.initializer = initializer
        self.learning_rate = learning_rate
        self.regularizer = regularizer
        self.trainable = trainable
        self.do_model_average = do_model_average
        self.need_clip = need_clip

    def _set_default_initializer(self, initializer):
        """
        Set the default initializer, the initializer should be Constant,
        Uniform, Normal, Xavier, MSRA.

        An initializer already set on this ParamAttr always takes precedence;
        ``initializer`` is only used when none has been set yet.

        Args:
            initializer(Initializer|None): the default initializer to set.

        Returns:
            None

        Raises:
            ValueError: If ``initializer`` is None and no initializer has been
                set on this ParamAttr either.
        """
        if initializer is None:
            if self.initializer is None:
                raise ValueError("ParamAttr.initializer is not set")
            return

        # Keep a user-specified initializer; the default must not override it.
        if self.initializer is not None:
            return

        self.initializer = initializer

    def _set_default_param_initializer(self):
        """
        Set the default initializer for the parameter with Xavier.

        Args:
            None.

        Returns:
            None.
        """
        self._set_default_initializer(Xavier())

    def _set_default_bias_initializer(self):
        """
        Set the default initializer for the bias with Constant(0.0).

        Args:
            None.

        Returns:
            None.
        """
        self._set_default_initializer(Constant(0.0))

    @staticmethod
    def _to_attr(arg):
        """
        Create ParamAttr[s] from a loosely-typed argument.

        Args:
            arg: Argument used to build the ParamAttr[s]. Supported types:
                None (a default ParamAttr), ParamAttr (returned unchanged),
                str (used as the name), Initializer, WeightDecayRegularizer,
                bool (True gives a default ParamAttr, False is returned as False),
                or a list/tuple of the above (converted element-wise).

        Returns:
            ParamAttr, a list of ParamAttr (for list/tuple input), or False
            (when ``arg`` is False).

        Raises:
            TypeError: If ``arg`` has a type that cannot initialize a ParamAttr.
        """
        if arg is None:
            return ParamAttr()
        elif isinstance(arg, (list, tuple)):
            return [ParamAttr._to_attr(a) for a in arg]
        elif isinstance(arg, ParamAttr):
            return arg
        elif isinstance(arg, six.string_types):
            return ParamAttr(name=arg)
        elif isinstance(arg, Initializer):
            return ParamAttr(initializer=arg)
        elif isinstance(arg, WeightDecayRegularizer):
            return ParamAttr(regularizer=arg)
        elif isinstance(arg, bool):
            # False explicitly disables the attribute (e.g. "no bias").
            return ParamAttr._to_attr(None) if arg else False
        else:
            raise TypeError("{0} cast to ParamAttr".format(type(arg)))

    def _to_kwargs(self, with_initializer=False):
        """
        Returns the attributes of this parameter.

        Args:
            with_initializer(bool): Whether to add the ``initializer`` entry.

        Returns:
            dict: The attributes of this parameter, with the learning rate
            nested under ``'optimize_attr'``.
        """
        kwargs = {
            'name': self.name,
            'optimize_attr': {
                'learning_rate': self.learning_rate
            },
            'regularizer': self.regularizer,
            'trainable': self.trainable,
            'do_model_average': self.do_model_average,
            'need_clip': self.need_clip
        }
        if with_initializer:
            kwargs['initializer'] = self.initializer
        return kwargs
G
guosheng 已提交
209 210 211


class WeightNormParamAttr(ParamAttr):
    r"""
    :api_attr: Static Graph

    Note:
        Please use 'paddle.nn.utils.weight_norm' in dygraph mode.

    Parameter attribute for Weight Norm. Weight Norm is a reparameterization of the weight vectors
    in a neural network that decouples the magnitude of those weight vectors from
    their direction. Weight Norm has been implemented as discussed in this
    paper: `Weight Normalization: A Simple Reparameterization to Accelerate
    Training of Deep Neural Networks
    <https://arxiv.org/pdf/1602.07868.pdf>`_.

    Note:
        ``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
        There are three clipping strategies: :ref:`api_paddle_nn_ClipGradByGlobalNorm` ,
        :ref:`api_paddle_nn_ClipGradByNorm` , :ref:`api_paddle_nn_ClipGradByValue` .

    Args:
        dim(int, optional): Dimension over which to compute the norm. Dim is a non-negative
            number which is less than the rank of weight Tensor. For Example, dim can
            be chosen from 0, 1, 2, 3 for convolution whose weight shape is [cout, cin, kh, kw]
            and rank is 4. Default None, meaning that all elements will be normalized.
        name(str, optional): The parameter's name. Default None, meaning that the name would
            be created automatically. Please refer to :ref:`api_guide_Name` for more details.
        initializer(Initializer, optional): The method to initialize this parameter, such as
            ``initializer = paddle.nn.initializer.Constant(1.0)``. Default None,
            meaning that the weight parameter is initialized by Xavier initializer, and
            the bias parameter is initialized by 0.
        learning_rate(float, optional): The parameter's learning rate. The learning rate
            used during optimization is :math:`global\_lr * parameter\_lr * scheduler\_factor`.
            Default 1.0.
        regularizer (WeightDecayRegularizer, optional): Regularization strategy. There are
            two methods: :ref:`api_paddle_regularizer_L1Decay` ,
            :ref:`api_paddle_regularizer_L2Decay`.
            If regularizer is also set in ``optimizer``
            (such as :ref:`api_paddle_optimizer_SGD` ), that regularizer setting in
            optimizer will be ignored. Default None, meaning there is no regularization.
        trainable(bool, optional): Whether this parameter is trainable. Default True.
        do_model_average(bool, optional): Whether this parameter should do model average.
            Default False.
        need_clip (bool, optional): Whether the parameter gradient needs to be clipped in optimizer. Default is True.

    Examples:
        .. code-block:: python

            import paddle

            paddle.enable_static()

            data = paddle.static.data(name="data", shape=[3, 32, 32], dtype="float32")

            fc = paddle.static.nn.fc(x=data,
                                     size=1000,
                                     weight_attr=paddle.static.WeightNormParamAttr(
                                         dim=None,
                                         name='weight_norm_param',
                                         initializer=paddle.nn.initializer.Constant(1.0),
                                         learning_rate=1.0,
                                         regularizer=paddle.regularizer.L2Decay(0.1),
                                         trainable=True,
                                         do_model_average=False,
                                         need_clip=True))

    """
    # List to record the parameters reparameterized by weight normalization.
    # If these parameters are treated as Variable rather than Parameter,
    # it can be used to discriminate these parameters and help to serialize
    # these parameters for inference.
    # NOTE: this is a class attribute, so the same list is shared by every
    # WeightNormParamAttr instance (it is a registry, not per-instance state).
    params_with_weight_norm = []

    def __init__(self,
                 dim=None,
                 name=None,
                 initializer=None,
                 learning_rate=1.0,
                 regularizer=None,
                 trainable=True,
                 do_model_average=False,
                 need_clip=True):
        # All common attributes (and their type checks) are handled by ParamAttr.
        super(WeightNormParamAttr, self).__init__(
            name=name,
            initializer=initializer,
            learning_rate=learning_rate,
            regularizer=regularizer,
            trainable=trainable,
            do_model_average=do_model_average,
            need_clip=need_clip)
        # Dimension over which the norm is computed; None normalizes all elements.
        self.dim = dim