clip.py 16.3 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
F
fengjiayi 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
F
fengjiayi 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
F
update  
fengjiayi 已提交
14

15 16
from __future__ import print_function

F
fengjiayi 已提交
17
import copy
18
import six
F
fengjiayi 已提交
19

Y
Yu Yang 已提交
20
import functools
21 22
from . import layers
from . import framework
F
fengjiayi 已提交
23
from . import core
Y
Yu Yang 已提交
24

F
fengjiayi 已提交
25
# Names exported by a star-import of this module. The short ``ClipBy*``
# aliases and the helper functions ``error_clip_callback`` /
# ``append_gradient_clip_ops`` are deliberately not listed here.
__all__ = [
    'set_gradient_clip', 'ErrorClipByValue', 'GradientClipByValue',
    'GradientClipByNorm', 'GradientClipByGlobalNorm'
]
Y
Yu Yang 已提交
32 33


F
fengjiayi 已提交
34
class BaseErrorClipAttr(object):
    """Interface for error-clip attributes attached to forward variables.

    Concrete subclasses must override both hooks below; this base class
    only raises ``NotImplementedError`` from each of them.
    """

    def __str__(self):
        """Describe the clip configuration (must be overridden)."""
        raise NotImplementedError()

    def _append_clip_op(self, block, grad_name):
        """Append to ``block`` the op(s) clipping the gradient ``grad_name``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()


class ErrorClipByValue(BaseErrorClipAttr):
    """
    Clips tensor values to the range [min, max].

    When set as the error clip of a variable, the gradient flowing back into
    that variable is clipped to [min, max] in place (the appended ``clip`` op
    reads and writes the same gradient variable):

    - Any values less than min are set to min.
    - Any values greater than max are set to max.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by
            the user, it will be set to -max by the framework.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            BATCH_SIZE = 128
            CLIP_MAX = 2e-6
            CLIP_MIN = -1e-6
            prog = fluid.framework.Program()
            with fluid.program_guard(main_program=prog):
                image = fluid.layers.data(name='x', shape=[784], dtype='float32')
                hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
                hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
                predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
                label = fluid.layers.data(name='y', shape=[1], dtype='int64')
                cost = fluid.layers.cross_entropy(input=predict, label=label)
                avg_cost = fluid.layers.mean(cost)
            prog_clip = prog.clone()
            prog_clip.block(0).var(hidden1.name)._set_error_clip(
                fluid.clip.ErrorClipByValue(
                    max=CLIP_MAX, min=CLIP_MIN))
    """

    def __init__(self, max, min=None):
        # Both bounds are stored as floats; a missing lower bound mirrors
        # the upper bound, giving the symmetric range [-max, max].
        max = float(max)
        min = -max if min is None else float(min)
        self.max = max
        self.min = min

    def __str__(self):
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def _append_clip_op(self, block, grad_name):
        """Append a ``clip`` op on ``grad_name`` to ``block``'s program desc.

        Args:
            block: block whose desc receives the new op.
            grad_name (str): name of the gradient variable to clip.
        """
        clip_op_desc = block.desc.append_op()
        clip_op_desc.set_type("clip")
        # Input and output are the same variable: the gradient is clipped
        # in place rather than producing a new variable.
        clip_op_desc.set_input("X", [grad_name])
        clip_op_desc.set_output("Out", [grad_name])
        clip_op_desc._set_attr("min", self.min)
        clip_op_desc._set_attr("max", self.max)
F
fengjiayi 已提交
97 98 99 100 101 102


def error_clip_callback(block, context):
    """Apply error clips to the gradients produced by ``block``'s last op.

    Every output of the most recently appended op that is a known gradient
    is mapped back to its forward variable. If that variable carries an
    ``error_clip`` attribute, the attribute appends its clip op for the
    gradient.

    Args:
        block: block whose last op is inspected and which receives any
            clip ops.
        context (dict): map from gradient variable name to forward
            variable name.

    Raises:
        TypeError: if a forward variable's ``error_clip`` is neither
            ``None`` nor a ``BaseErrorClipAttr`` instance.
    """
    grad_to_var = context
    last_op_desc = block.desc.op(block.desc.op_size() - 1)
    for grad_name in last_op_desc.output_arg_names():
        if grad_name not in grad_to_var:
            continue
        fwd_var = block._var_recursive(grad_to_var[grad_name])
        error_clip = getattr(fwd_var, "error_clip", None)
        if error_clip is None:
            continue
        if not isinstance(error_clip, BaseErrorClipAttr):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
        error_clip._append_clip_op(block, grad_name)
F
fengjiayi 已提交
113 114


Y
Yu Yang 已提交
115
class BaseGradientClipAttr(object):
    """Interface for gradient-clip attributes attached to parameters.

    ``append_gradient_clip_ops`` drives subclasses in two passes. First it
    calls ``_process_context`` for every (param, grad) pair. It then calls
    ``_create_operators`` for every pair. All hooks raise
    ``NotImplementedError`` here and must be overridden.
    """

    def __str__(self):
        """Describe the clip configuration (must be overridden)."""
        raise NotImplementedError()

    def _process_context(self, context, param, grad):
        """Record per-pair information in the shared ``context`` dict.

        Called for every pair before any ``_create_operators`` call.
        """
        raise NotImplementedError()

    def _create_operators(self, param, grad):
        """Append clip ops and return the ``(param, new_grad)`` pair."""
        raise NotImplementedError()


class NullGradientClipAttr(BaseGradientClipAttr):
    """Gradient-clip attribute that leaves gradients untouched.

    Used as the fallback for parameters without a clip attribute.
    """

    def __str__(self):
        return "Null"

    def _process_context(self, context, param, grad):
        # No cross-parameter state is needed when nothing is clipped.
        return None

    def _create_operators(self, param, grad):
        # Hand the pair back unchanged; no ops are appended.
        return param, grad


class GradientClipByValue(BaseGradientClipAttr):
    """
    Clips gradient values to the range [min, max].

    Given a gradient tensor t, this produces a clipped copy of t:

    - Any values less than min are set to min.
    - Any values greater than max are set to max.

    Each gradient is clipped independently of all other gradients.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by
            the user, it will be set to -max by the framework.

    Note:
        The positional order is ``(max, min)``; pass the bounds as keywords
        to avoid accidentally swapping them.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            w_param_attrs = fluid.ParamAttr(name=None,
                initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0),
                learning_rate=1.0,
                regularizer=fluid.regularizer.L1Decay(1.0),
                trainable=True,
                gradient_clip=fluid.clip.GradientClipByValue(min=-1.0, max=1.0))
            x = fluid.layers.data(name='x', shape=[10], dtype='float32')
            y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
    """

    def __init__(self, max, min=None):
        # Both bounds are stored as floats; a missing lower bound mirrors
        # the upper bound, giving the symmetric range [-max, max].
        max = float(max)
        min = -max if min is None else float(min)
        self.max = max
        self.min = min

    def __str__(self):
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def _process_context(self, context, param, grad):
        # Element-wise clipping needs no information from other gradients.
        pass

    def _create_operators(self, param, grad):
        """Return ``(param, clip(grad, min, max))``."""
        new_grad = layers.clip(x=grad, min=self.min, max=self.max)
        return param, new_grad


F
fengjiayi 已提交
185
class GradientClipByNorm(BaseGradientClipAttr):
    r"""
    Clips tensor values to a maximum L2-norm.

    This operator limits the L2 norm of the input :math:`X` within :math:`max\_norm`.
    If the L2 norm of :math:`X` is less than or equal to :math:`max\_norm`, :math:`Out`
    will be the same as :math:`X`. If the L2 norm of :math:`X` is greater than
    :math:`max\_norm`, :math:`X` will be linearly scaled to make the L2 norm of
    :math:`Out` equal to :math:`max\_norm`, as shown in the following formula:

    .. math::

        Out = \frac{max\_norm * X}{norm(X)},

    where :math:`norm(X)` represents the L2 norm of :math:`X`.

    Each gradient is clipped using only its own norm, independently of
    all other gradients.

    Args:
        clip_norm (float): The maximum norm value.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            w_param_attrs = fluid.ParamAttr(name=None,
                initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0),
                learning_rate=1.0,
                regularizer=fluid.regularizer.L1Decay(1.0),
                trainable=True,
                gradient_clip=fluid.clip.GradientClipByNorm(clip_norm=2.0))
            x = fluid.layers.data(name='x', shape=[10], dtype='float32')
            y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)

    """

    def __init__(self, clip_norm):
        self.clip_norm = clip_norm

    def __str__(self):
        return "ByNorm, clip_norm=%f" % self.clip_norm

    def _process_context(self, context, param, grad):
        # Per-tensor clipping needs no information from other gradients.
        pass

    def _create_operators(self, param, grad):
        """Return ``(param, clip_by_norm(grad, clip_norm))``."""
        new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
        return param, new_grad


F
fengjiayi 已提交
233
class GradientClipByGlobalNorm(BaseGradientClipAttr):
    r"""
    Clips values of multiple tensors by the ratio of the sum of their norms.

    All gradients whose parameters share the same ``group_name`` are scaled
    by one common factor. Given the tensors :math:`t\_list` of a group and
    the maximum norm ``clip_norm``, each tensor is replaced by:

    .. math::

        t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}

    where:

    .. math::

        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are,
    otherwise they're all shrunk by the global ratio.

    Args:
        clip_norm (float): The maximum norm value. Every parameter of the
            same group must use the same value.
        group_name (str, optional): The group name for this clip.
            Default "default_group".

    Raises:
        TypeError: if ``group_name`` is not a string.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            prog = fluid.framework.Program()
            startup_program = fluid.framework.Program()
            with fluid.program_guard(
                    main_program=prog, startup_program=startup_program):
                image = fluid.layers.data(name='x', shape=[784], dtype='float32')
                label = fluid.layers.data(name='y', shape=[1], dtype='int64')
                hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
                hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
                predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
                cost = fluid.layers.cross_entropy(input=predict, label=label)
                avg_cost = fluid.layers.mean(cost)
            prog_clip = prog.clone()
            avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
            p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)

            with fluid.program_guard(main_program=prog_clip):
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
                p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)

    """

    def __init__(self, clip_norm, group_name="default_group"):
        if not isinstance(group_name, six.string_types):
            raise TypeError("'group_name' must be a %s." % (six.string_types))

        self.clip_norm = clip_norm
        self.group_name = group_name

    def __str__(self):
        return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name,
                                                              self.clip_norm)

    def _process_context(self, context, param, grad):
        """Add the squared L2 norm of ``grad`` to its group's entry in ``context``.

        The first parameter seen for a group creates the group's entries:

        - ``<group_name>``: list of per-gradient squared-norm variables;
        - ``<group_name>_clip_value``: the Python ``clip_norm`` used to check
          that all members of the group agree on it;
        - ``<group_name>_clip``: a ``[1]``-shaped float32 constant holding
          ``clip_norm``.

        Raises:
            ValueError: if ``clip_norm`` differs from the value already
                registered for this group.
        """
        if self.group_name not in context:
            context[self.group_name] = []
            context[self.group_name + "_clip_value"] = self.clip_norm
            context[self.group_name + "_clip"] = layers.fill_constant(
                shape=[1], dtype="float32", value=self.clip_norm)
        elif self.clip_norm != context[self.group_name + "_clip_value"]:
            raise ValueError(
                "All parameters' 'clip_norm' of a same group should be the same"
            )

        merge_grad = grad
        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
            # Sparse gradient: merge its rows and take the dense tensor before
            # squaring (presumably so duplicated rows are combined first --
            # confirm against merge_selected_rows).
            merge_grad = layers.merge_selected_rows(grad)
            merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

        square = layers.square(merge_grad)
        local_norm_var = layers.reduce_sum(input=square)
        context[self.group_name].append(local_norm_var)

        # Keep a handle on the context so _create_operators can read the
        # group's accumulated norms and cache the group scale in it.
        self.context = context

    def _create_operators(self, param, grad):
        """Return ``(param, grad * group_scale)``.

        The group scale ``clip / max(clip, sqrt(sum of squared norms))`` is
        built by the first call for a group and cached in the context under
        ``<group_name>_scale``; later calls for the group reuse it.
        Requires ``_process_context`` to have been called first.
        """
        group_scale_name = self.group_name + "_scale"
        if group_scale_name not in self.context:
            group_norm_var = layers.sums(input=self.context[self.group_name])
            group_norm_var = layers.sqrt(x=group_norm_var)
            clip_var = self.context[self.group_name + "_clip"]
            group_scale_var = layers.elementwise_div(
                x=clip_var,
                y=layers.elementwise_max(
                    x=clip_var, y=group_norm_var))
            assert group_scale_var.shape == (1, )
            self.context[group_scale_name] = group_scale_var

        new_grad = layers.elementwise_mul(
            x=grad, y=self.context[group_scale_name])

        return param, new_grad
F
fengjiayi 已提交
337 338


339
@framework.dygraph_not_support
def set_gradient_clip(clip, param_list=None, program=None):
    """
    To specify parameters that require gradient clip.

    Args:
        clip (BaseGradientClipAttr): An instance of some derived class of BaseGradientClipAttr,
                for example :ref:`api_fluid_clip_GradientClipByGlobalNorm` ,
                which describes the type and detailed attributes of required gradient clip.
        param_list (list(Variable), optional): Parameters that require gradient clip.
                It can be a list of parameters, a list of parameter names, or a
                mix of both; any iterable is accepted.
                Default None, meaning that all parameters in the program will be included.
        program (Program, optional): The program where parameters are located.
                Default None, meaning that using :ref:`api_fluid_default_main_program` .

    Returns:
        None

    Raises:
        TypeError: if ``clip`` is not a BaseGradientClipAttr, or if an entry
            of ``param_list`` (after resolving names) is not a Parameter.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            def network():
                image = fluid.layers.data(name='image', shape=[28], dtype='float32')
                param_attr1 = fluid.ParamAttr("fc1_param")
                fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
                param_attr2 = fluid.ParamAttr("fc2_param")
                fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
                loss = fluid.layers.reduce_mean(fc2)
                return loss


            # network 1: clip all parameter gradient
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 2: clip parameter gradient by name
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=["fc1_param", "fc2_param"])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 3: clip parameter gradient by var
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                param_var1 = fluid.default_main_program().global_block().var("fc1_param")
                param_var2 = fluid.default_main_program().global_block().var("fc2_param")
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=[param_var1, param_var2])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)
    """
    if not isinstance(clip, BaseGradientClipAttr):
        raise TypeError(
            "'clip' should be an instance of BaseGradientClipAttr's derived class"
        )
    if program is None:
        program = framework.default_main_program()
    if param_list is None:
        param_list = program.block(0).all_parameters()

    # Materialize once and resolve names individually. The list is iterated
    # more than once below; a generator argument would otherwise be exhausted
    # by the type check and silently leave every parameter unconfigured.
    block = program.block(0)
    param_list = [
        block.var(elem) if isinstance(elem, six.string_types) else elem
        for elem in param_list
    ]
    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    for param in param_list:
        # Each parameter gets its own copy: clip attributes may keep
        # per-instance state (GradientClipByGlobalNorm stores ``self.context``).
        param.gradient_clip_attr = copy.deepcopy(clip)
F
fengjiayi 已提交
417 418


419
def append_gradient_clip_ops(param_grads):
    """Append gradient-clipping ops for every (param, grad) pair.

    Each parameter's own ``gradient_clip_attr`` is applied to its gradient.
    ``NullGradientClipAttr`` is used when the attribute is absent or
    ``None``. Clipping runs in two passes over a shared ``context`` dict.
    The first pass calls ``_process_context`` for every pair; the second
    calls ``_create_operators``. This lets attributes such as
    ``GradientClipByGlobalNorm`` see all gradients before creating ops.

    Args:
        param_grads (iterable of (Parameter, Variable)): parameter/gradient
            pairs. Pairs whose gradient is ``None`` are skipped.

    Returns:
        list: one ``_create_operators`` result (a ``(param, new_grad)`` pair
        for the attributes in this module) per pair with a gradient, in
        input order.

    Raises:
        TypeError: if a parameter's ``gradient_clip_attr`` is not a
            ``BaseGradientClipAttr``.
    """
    context = dict()
    # Remember which clip attribute belongs to which pair. Previously the
    # second pass reused whatever ``clip_attr`` the first loop ended on, so
    # every gradient was clipped with the *last* parameter's attribute.
    # Collecting the pairs here also means ``param_grads`` is iterated once.
    clip_items = []
    for p, g in param_grads:
        if g is None:
            continue
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('append_clip'):
            clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
            if clip_attr is None:
                clip_attr = NullGradientClipAttr()
            if not isinstance(clip_attr, BaseGradientClipAttr):
                raise TypeError(
                    "clip attribute should be an instance of BaseGradientClipAttr"
                )

            clip_attr._process_context(context=context, param=p, grad=g)
        clip_items.append((p, g, clip_attr))

    res = []
    for p, g, clip_attr in clip_items:
        # The name-scope string (including its spelling) is kept verbatim so
        # generated programs stay unchanged.
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('append_graident_clip'):
            res.append(clip_attr._create_operators(param=p, grad=g))

    return res
Y
Yu Yang 已提交
445 446 447


# Short aliases for the gradient-clip classes (not listed in ``__all__``).
ClipByValue, ClipByNorm, ClipByGlobalNorm = (
    GradientClipByValue,
    GradientClipByNorm,
    GradientClipByGlobalNorm,
)