clip.py 34.5 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
F
fengjiayi 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
F
fengjiayi 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
F
update  
fengjiayi 已提交
14

15 16
from __future__ import print_function

F
fengjiayi 已提交
17
import copy
18
import six
19
import warnings
F
fengjiayi 已提交
20

Y
Yu Yang 已提交
21
import functools
W
WangXi 已提交
22
import paddle
23 24
from . import layers
from . import framework
F
fengjiayi 已提交
25
from . import core
C
Chengmo 已提交
26
from . import name_scope
27
from .dygraph import base as imperative_base
W
WangXi 已提交
28
from .data_feeder import check_variable_and_dtype
J
Jiabin Yang 已提交
29
from .framework import _non_static_mode
W
WangXi 已提交
30
from .layer_helper import LayerHelper
31
from .framework import default_main_program
32
from paddle import _C_ops
Y
Yu Yang 已提交
33

F
fengjiayi 已提交
34
# Public API exported by ``from ... clip import *``.
__all__ = [
    'set_gradient_clip',
    'ErrorClipByValue',
    'ClipGradByValue',
    'ClipGradByNorm',
    'ClipGradByGlobalNorm',
]
Y
Yu Yang 已提交
38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
_clip_by_global_norm_using_mp_type_flag = False


def _clip_by_global_norm_using_mp_type(*args):
    """Get or set the module flag read by ``_cast_to_mp_type_if_enabled``.

    While the flag is True, ``_cast_to_mp_type_if_enabled`` casts FP16
    inputs to FP32.

    Called with no arguments, returns the current flag value. Called with a
    single ``bool``, stores it and returns the previous value. Any other
    argument count or a non-bool argument fails an ``assert``.
    """
    global _clip_by_global_norm_using_mp_type_flag
    assert len(args) <= 1
    if not args:
        return _clip_by_global_norm_using_mp_type_flag
    new_value, = args
    assert isinstance(new_value, bool)
    previous = _clip_by_global_norm_using_mp_type_flag
    _clip_by_global_norm_using_mp_type_flag = new_value
    return previous


def _cast_to_mp_type_if_enabled(x):
    """Return ``x`` cast to FP32 if it is FP16 and mp-type clipping is on.

    Mp-type clipping is toggled via ``_clip_by_global_norm_using_mp_type``.
    In every other case ``x`` is returned unchanged. The flag is only
    queried when ``x`` is FP16.
    """
    is_fp16 = x.dtype == core.VarDesc.VarType.FP16
    if not (is_fp16 and _clip_by_global_norm_using_mp_type()):
        return x
    return x.astype(core.VarDesc.VarType.FP32)

Y
Yu Yang 已提交
61

W
WangXi 已提交
62 63 64 65 66
def _squared_l2_norm(x):
    r"""
    This OP returns the squared L2 norm of a tensor, i.e. the sum of the
    squares of all elements of ``x``.

    ``x`` is first passed through ``_cast_to_mp_type_if_enabled``, which may
    promote FP16 input to FP32.

    Args:
        x: Input tensor/variable.

    Returns:
        The squared L2 norm of ``x``.
    """
    x = _cast_to_mp_type_if_enabled(x)

    # XPU builds and (still) FP16 inputs use the generic square + reduce_sum
    # path rather than the dedicated op (presumably for kernel availability
    # on those targets -- confirm).
    if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
        return layers.reduce_sum(layers.square(x))

    if _non_static_mode():
        return _C_ops.squared_l2_norm(x)

    # Static graph: append a 'squared_l2_norm' op. NOTE: ``**locals()`` is
    # forwarded to LayerHelper, so only ``x`` and ``op_type`` may be bound
    # as locals at this point.
    op_type = 'squared_l2_norm'
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(type=op_type, inputs={"X": x}, outputs={'Out': out})
    return out


F
fengjiayi 已提交
87
class BaseErrorClipAttr(object):
    """Interface for error-clip attributes attached to forward variables.

    ``error_clip_callback`` calls ``_append_clip_op`` on a variable's
    ``error_clip`` to emit the op that clips the corresponding gradient.
    Subclasses must override both methods.
    """

    def __str__(self):
        raise NotImplementedError()

    def _append_clip_op(self, block, grad_name):
        """Append the op that clips gradient ``grad_name`` inside ``block``."""
        raise NotImplementedError()


class ErrorClipByValue(BaseErrorClipAttr):
    r"""
    Clips tensor values to the range [min, max].

    Given a tensor ``t`` (see Examples below), this operation clips its value
    to ``min`` and ``max`` inplace. The attribute is attached to a forward
    variable (e.g. via ``_set_error_clip``), and the clip is applied in place
    to that variable's gradient when the backward ops are built.

    - Any values less than min are set to min.
    - Any values greater than max are set to max.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by user,
            it will be set to ``-max`` by the framework.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            BATCH_SIZE = 128
            CLIP_MAX = 2e-6
            CLIP_MIN = -1e-6
            prog = fluid.framework.Program()
            with fluid.program_guard(main_program=prog):
                image = fluid.layers.data(
                    name='x', shape=[784], dtype='float32')
                hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
                hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
                predict = fluid.layers.fc(
                    input=hidden2, size=10, act='softmax')
                label = fluid.layers.data(name='y', shape=[1], dtype='int64')
                cost = fluid.layers.cross_entropy(input=predict, label=label)
                avg_cost = fluid.layers.mean(cost)
            prog_clip = prog.clone()
            prog_clip.block(0).var(hidden1.name)._set_error_clip(
                fluid.clip.ErrorClipByValue(
                    max=CLIP_MAX, min=CLIP_MIN))
    """

    def __init__(self, max, min=None):
        # NOTE(review): ``min > max`` is not rejected here; callers must pass
        # a consistent range.
        max = float(max)
        if min is None:
            min = -max
        else:
            min = float(min)
        self.max = max
        self.min = min

    def __str__(self):
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def _append_clip_op(self, block, grad_name):
        """Append an in-place ``clip`` op on ``grad_name`` to ``block``.

        The op reads and writes the same variable, clamping it to
        [``self.min``, ``self.max``].
        """
        clip_op_desc = block.desc.append_op()
        clip_op_desc.set_type("clip")
        clip_op_desc.set_input("X", [grad_name])
        clip_op_desc.set_output("Out", [grad_name])
        clip_op_desc._set_attr("min", self.min)
        clip_op_desc._set_attr("max", self.max)
F
fengjiayi 已提交
153 154 155 156 157 158


def error_clip_callback(block, context):
    """Append error-clip ops for the outputs of the last op in ``block``.

    Args:
        block: Block whose most recently appended op is inspected.
        context: Mapping from gradient variable name to forward variable
            name (the "grad_to_var" map).

    For every output of that op that appears in ``context``, the forward
    variable's ``error_clip`` attribute (if any) appends its clip op for the
    gradient.

    Raises:
        TypeError: if a forward variable's ``error_clip`` is neither ``None``
            nor a ``BaseErrorClipAttr`` instance.
    """
    grad_to_var = context
    last_op = block.desc.op(block.desc.op_size() - 1)
    for grad_name in last_op.output_arg_names():
        if grad_name not in grad_to_var:
            continue
        fwd_var = block._var_recursive(grad_to_var[grad_name])
        error_clip = getattr(fwd_var, "error_clip", None)
        if error_clip is None:
            continue
        if not isinstance(error_clip, BaseErrorClipAttr):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
        error_clip._append_clip_op(block, grad_name)
F
fengjiayi 已提交
169 170


171 172 173
class ClipGradBase(object):
    """Abstract base class for gradient-clipping strategies.

    Calling an instance with a list of ``(param, grad)`` pairs dispatches to
    ``_dygraph_clip`` in dynamic-graph mode or ``_static_clip`` otherwise.
    Subclasses also implement the ``_process_context`` /
    ``_create_operators`` hooks.
    """

    def __init__(self):
        super(ClipGradBase, self).__init__()

    def __str__(self):
        raise NotImplementedError()

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        # Dynamic-graph clipping; overridden by subclasses.
        raise NotImplementedError

    def _static_clip(self, params_grads):
        # Static-graph clipping; overridden by subclasses.
        raise NotImplementedError

    def __call__(self, params_grads):
        """Clip ``params_grads`` with the implementation for the current mode.

        In static mode, if any parameter carries a ``gradient_clip_attr``
        (presumably set by ``set_gradient_clip``), a single warning is
        emitted before delegating to ``_static_clip``.
        """
        if framework._non_static_mode():
            return self._dygraph_clip(params_grads)

        has_legacy_clip_attr = any(
            getattr(param, 'gradient_clip_attr', None) is not None
            for param, _ in params_grads)
        if has_legacy_clip_attr:
            warnings.warn(
                "'set_gradient_clip' will be ineffective, because you have "
                "set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' "
                "is redundant and you can remove it.")
        return self._static_clip(params_grads)

    def _process_context(self, context, param, grad):
        raise NotImplementedError()

    def _create_operators(self, param, grad):
        raise NotImplementedError()
Y
Yu Yang 已提交
203 204


205
class ClipGradByValue(ClipGradBase):
    """
    Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].

    - Any values less than min are set to ``min``.

    - Any values greater than max are set to ``max``.

    The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip takes effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    Note:
        ``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by user, it will be set to ``-max``
            automatically. In this case, ``max`` must be greater than 0.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByValue(min=-1, max=1)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(self, max, min=None):
        super(ClipGradByValue, self).__init__()
        if min is None:
            # A symmetric range [-max, max] is only meaningful for max > 0.
            assert (max > 0.0)
            min = -max
        # NOTE(review): an explicit ``min`` greater than ``max`` is not
        # rejected here.
        self.max = float(max)
        self.min = float(min)

    def __str__(self):
        return "Clip Gradient By Value, min = %f, max=%f" % (self.min, self.max)

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Clip each gradient element-wise to [min, max] in dynamic mode.

        Pairs whose gradient is ``None`` are dropped from the result; pairs
        whose parameter has ``need_clip`` set to ``False`` pass through
        unchanged.
        """
        params_and_grads = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            new_grad = layers.clip(x=g, min=self.min, max=self.max)
            params_and_grads.append((p, new_grad))
        return params_and_grads

    def _static_clip(self, params_grads):
        """Append element-wise clip ops for each gradient in static mode.

        Same filtering as ``_dygraph_clip``. The names of the new clipped
        gradients are handed to ``_correct_clip_op_role_var`` (defined
        elsewhere in this module; presumably it re-points op role vars to
        the new gradient names -- confirm).
        """
        params_and_grads = []
        param_new_grad_name_dict = dict()
        with framework.name_scope('gradient_clip'):
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    new_grad = layers.clip(x=g, min=self.min, max=self.max)
                params_and_grads.append((p, new_grad))
                param_new_grad_name_dict[p.name] = new_grad.name
        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        # Value clipping is per-element and needs no shared context.
        pass

    def _create_operators(self, param, grad):
        new_grad = layers.clip(x=grad, min=self.min, max=self.max)
        return param, new_grad


296
class ClipGradByNorm(ClipGradBase):
    r"""
    Limit the l2 norm of multi-dimensional Tensor :math:`X` to ``clip_norm`` .

    - If the l2 norm of :math:`X` is greater than ``clip_norm`` , :math:`X` will be compressed by a ratio.

    - If the l2 norm of :math:`X` is less than or equal to ``clip_norm`` , nothing will be done.

    The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    Each gradient is clipped independently by its own norm.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip takes effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::
        Out =
        \left\{
            \begin{array}{ccl}
                X & & if (norm(X) \leq clip\_norm) \\
                \frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\
        \end{array}
        \right.


    where :math:`norm(X)` represents the L2 norm of :math:`X`.

    .. math::
        norm(X) = ( \sum_{i=1}^{n}|x_i|^2)^{ \frac{1}{2}}

    Note:
        ``need_clip`` of ``ClipGradByNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        clip_norm(float): The maximum norm value.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(self, clip_norm):
        super(ClipGradByNorm, self).__init__()
        self.clip_norm = float(clip_norm)

    def __str__(self):
        return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Clip each gradient by its own L2 norm in dynamic mode.

        Pairs whose gradient is ``None`` are dropped from the result; pairs
        whose parameter has ``need_clip`` set to ``False`` pass through
        unchanged.
        """
        params_and_grads = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
            params_and_grads.append((p, new_grad))
        return params_and_grads

    def _static_clip(self, params_grads):
        """Append a ``clip_by_norm`` op per gradient in static mode.

        Same filtering as ``_dygraph_clip``. The names of the new clipped
        gradients are handed to ``_correct_clip_op_role_var`` (defined
        elsewhere in this module).
        """
        params_and_grads = []
        with framework.name_scope('gradient_clip'):
            param_new_grad_name_dict = dict()
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
                param_new_grad_name_dict[p.name] = new_grad.name
                params_and_grads.append((p, new_grad))
        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        # Per-tensor norm clipping needs no shared context.
        pass

    def _create_operators(self, param, grad):
        new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
        return param, new_grad


398 399 400 401 402 403 404 405 406 407 408 409 410 411
_allow_pure_fp16_global_norm_clip_flag = False


def _allow_pure_fp16_global_norm_clip(*args):
    """Get or set whether static global-norm clipping may stay in FP16.

    The flag is read by ``ClipGradByGlobalNorm._static_clip``. When it is
    True and every squared norm is FP16, their sum is not cast to float32.

    Called with no arguments, returns the current flag value. Called with a
    single ``bool``, stores it and returns the previous value. Any other
    call shape fails an ``assert``.
    """
    global _allow_pure_fp16_global_norm_clip_flag
    if args:
        assert len(args) == 1 and isinstance(args[0], bool)
        previous = _allow_pure_fp16_global_norm_clip_flag
        _allow_pure_fp16_global_norm_clip_flag = args[0]
        return previous
    return _allow_pure_fp16_global_norm_clip_flag


412
class ClipGradByGlobalNorm(ClipGradBase):
    r"""
    Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
    :math:`t\_list` , and limit it to ``clip_norm`` .

    - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.

    - If the global norm is less than or equal to ``clip_norm`` , nothing will be done.

    The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped,
    and they do not contribute to the global norm either.

    Gradient clip takes effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::

        t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}

    where:

    .. math::

        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    Note:
        ``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        clip_norm (float): The maximum norm value.
        group_name (str, optional): The group name for this clip. Default value is ``default_group``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(self, clip_norm, group_name="default_group"):
        super(ClipGradByGlobalNorm, self).__init__()
        self.clip_norm = float(clip_norm)
        self.group_name = group_name

    def __str__(self):
        return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Scale all clippable gradients by a common factor in dynamic mode.

        Squared norms are bucketed by dtype (FP16 / FP32 / other) and summed
        in float64 if any non-FP16/FP32 norm exists, otherwise float32.
        Returns ``params_grads`` unchanged if nothing is clippable; otherwise
        returns a new list without ``None``-gradient pairs.
        """
        params_and_grads = []
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

            sum_square = _squared_l2_norm(merge_grad)
            if sum_square.dtype == core.VarDesc.VarType.FP16:
                sum_square_list_fp16.append(sum_square)
            elif sum_square.dtype == core.VarDesc.VarType.FP32:
                sum_square_list_fp32.append(sum_square)
            else:
                sum_square_list.append(sum_square)

        # all parameters have been filtered out
        if len(sum_square_list) + len(sum_square_list_fp16) + len(
                sum_square_list_fp32) == 0:
            return params_grads

        sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
        global_norm_var = []
        if len(sum_square_list_fp16) > 0:
            global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
            global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
        if len(sum_square_list_fp32) > 0:
            global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
            if sum_dtype == 'float32':
                global_norm_var.append(global_norm_var_fp32)
            else:
                global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
        if len(sum_square_list) > 0:
            global_norm_var_fp64 = paddle.add_n(sum_square_list)
            global_norm_var.append(global_norm_var_fp64)
        global_norm_var = paddle.add_n(global_norm_var)
        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)

        # only when global_norm_var > max_global_norm, grad need clip
        need_clip = False
        if global_norm_var > max_global_norm:
            need_clip = True

        if need_clip:
            clip_var = layers.elementwise_div(
                x=max_global_norm, y=global_norm_var)
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            # TODO(wangxi): use inplace elementwise_mul
            if need_clip:
                clip_input = (clip_var.astype('float16')
                              if g.dtype == core.VarDesc.VarType.FP16 else
                              clip_var)
                new_grad = layers.elementwise_mul(x=g, y=clip_input)
                params_and_grads.append((p, new_grad))
            else:
                params_and_grads.append((p, g))

        return params_and_grads

    def _static_clip(self, params_grads):
        """Append ops that scale all clippable gradients in place (static mode).

        The scale is ``clip_norm / max(clip_norm, global_norm)``. Gradients
        are multiplied in place (cast through the mp type when enabled), so
        the returned pairs reuse the original gradient variables.
        """
        params_and_grads = []
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        # The last (param, grad) pair that has a real gradient. Its names are
        # used as the role var of the shared global-norm ops. Using the raw
        # loop variables would pass ``None`` as the grad whenever the final
        # entry of ``params_grads`` has no gradient (e.g. a non-trainable
        # parameter).
        norm_anchor = None
        with framework.name_scope('gradient_clip'):
            for p, g in params_grads:
                if g is None:
                    continue
                norm_anchor = (p, g)
                if getattr(p, 'need_clip', True) is False:
                    continue
                merge_grad = g
                with p.block.program._optimized_guard([p, g]):
                    if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                        merge_grad = layers.merge_selected_rows(g)
                        merge_grad = layers.get_tensor_from_selected_rows(
                            merge_grad)
                    sum_square = _squared_l2_norm(merge_grad)
                    if sum_square.dtype == core.VarDesc.VarType.FP16:
                        sum_square_list_fp16.append(sum_square)
                    elif sum_square.dtype == core.VarDesc.VarType.FP32:
                        sum_square_list_fp32.append(sum_square)
                    else:
                        sum_square_list.append(sum_square)

            # all parameters have been filtered out
            if len(sum_square_list) + len(sum_square_list_fp16) + len(
                    sum_square_list_fp32) == 0:
                return params_grads

            # At least one squared norm exists, so norm_anchor is set.
            anchor_p, anchor_g = norm_anchor
            with anchor_p.block.program._optimized_guard(
                [anchor_p, anchor_g]):
                sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"

                global_norm_var = []
                if len(sum_square_list_fp16) > 0:
                    global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
                    # Stay in FP16 only for pure-FP16 models when explicitly
                    # allowed via _allow_pure_fp16_global_norm_clip.
                    if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip(
                    ):
                        global_norm_var.append(
                            global_norm_var_fp16.astype(sum_dtype))
                    else:
                        global_norm_var.append(global_norm_var_fp16)
                if len(sum_square_list_fp32) > 0:
                    global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
                    if sum_dtype == 'float32':
                        global_norm_var.append(global_norm_var_fp32)
                    else:
                        global_norm_var.append(
                            global_norm_var_fp32.astype(sum_dtype))
                if len(sum_square_list) > 0:
                    # fp64
                    global_norm_var_other_dtype = layers.sums(sum_square_list)
                    global_norm_var.append(global_norm_var_other_dtype)

                global_norm_var = layers.sums(global_norm_var) if len(
                    global_norm_var) > 1 else global_norm_var[0]
                global_norm_var = layers.sqrt(x=global_norm_var)
                max_global_norm = layers.fill_constant(
                    shape=[1],
                    dtype=global_norm_var.dtype,
                    value=self.clip_norm)
                # scale = clip_norm / max(clip_norm, global_norm) <= 1
                scale_var = layers.elementwise_div(
                    x=max_global_norm,
                    y=layers.elementwise_max(
                        x=max_global_norm, y=global_norm_var))
            param_new_grad_name_dict = dict()
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    new_g = _cast_to_mp_type_if_enabled(g)
                    # inplace
                    scale_input = (scale_var.astype('float16') if
                                   new_g.dtype == core.VarDesc.VarType.FP16 and
                                   scale_var.dtype != core.VarDesc.VarType.FP16
                                   else scale_var)
                    # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
                    # will be in different blocks with the gradient clip related ops.
                    # We need to handle the correct block, otherwise will encounter
                    # a 'NotFoundError' during compile time.
                    block = default_main_program().current_block()
                    block.append_op(
                        type='elementwise_mul',
                        inputs={'X': new_g,
                                'Y': scale_input},
                        outputs={'Out': new_g})
                    if new_g is not g:
                        # Write the scaled mp-type copy back into g.
                        block.append_op(
                            type='cast',
                            inputs={'X': new_g},
                            outputs={'Out': g},
                            attrs={
                                'in_dtype': new_g.dtype,
                                'out_dtype': g.dtype
                            })

                param_new_grad_name_dict[p.name] = g.name
                params_and_grads.append((p, g))

        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        """Record ``grad``'s squared norm in this clip group's ``context``.

        On first use of ``self.group_name`` the group list, its clip value
        and a constant holding ``clip_norm`` are created. Later calls for the
        same group must use the same ``clip_norm``.

        Raises:
            ValueError: if the group already exists with a different
                ``clip_norm``.
        """
        if self.group_name not in context:
            context[self.group_name] = []
            context[self.group_name + "_clip_value"] = self.clip_norm
            context[self.group_name + "_clip"] = layers.fill_constant(
                shape=[1], dtype=grad.dtype, value=self.clip_norm)
        else:
            if not self.clip_norm == context[self.group_name + "_clip_value"]:
                raise ValueError(
                    "All parameters' 'clip_norm' of a same group should be the same"
                )

        merge_grad = grad
        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
            merge_grad = layers.merge_selected_rows(grad)
            merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

        local_norm_var = _squared_l2_norm(merge_grad)
        context[self.group_name].append(local_norm_var)

        self.context = context

    def _create_operators(self, param, grad):
        """Scale ``grad`` in place by its group's global-norm factor.

        The group scale is built once per group (cached in ``self.context``
        under ``<group_name>_scale``) from the squared norms collected by
        ``_process_context``.
        """
        group_scale_name = self.group_name + "_scale"
        if group_scale_name not in self.context:
            group_norm_var = layers.sums(input=self.context[self.group_name])
            group_norm_var = layers.sqrt(x=group_norm_var)
            clip_var = self.context[self.group_name + "_clip"]
            group_scale_var = layers.elementwise_div(
                x=clip_var,
                y=layers.elementwise_max(
                    x=clip_var, y=group_norm_var))
            assert group_scale_var.shape == (1, )
            self.context[group_scale_name] = group_scale_var

        # inplace
        param.block.append_op(
            type='elementwise_mul',
            inputs={'X': grad,
                    'Y': self.context[group_scale_name]},
            outputs={'Out': grad})

        return param, grad
F
fengjiayi 已提交
696 697


698
@framework.dygraph_not_support
def set_gradient_clip(clip, param_list=None, program=None):
    """
    :api_attr: Static Graph

    Warning:

        This API must be used after building network, and before ``minimize`` ,
        and it may be removed in future releases, so it is not recommended.
        It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
        this is a better method to clip gradient. There are three clipping strategies:
         :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
         :ref:`api_fluid_clip_GradientClipByValue` .

    Specify the parameters whose gradients should be clipped, and how.

    Every selected parameter receives its own deep copy of ``clip`` as its
    ``gradient_clip_attr``. If block 0 of ``program`` already contains ops in
    an ``optimizer`` name scope (i.e. ``minimize`` has already been called),
    a warning is emitted, because the setting will not take effect.

    Args:
        clip (GradientClipBase): Gradient clipping strategy, an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). This argument is required.
        param_list (list(Variable), optional): Parameters that require gradient clip.
                It can be a list of parameters or a list of parameter names.
                Default None, meaning that all parameters in the program will be included.
        program (Program, optional): The program where parameters are located.
                Default None, meaning that using :ref:`api_fluid_default_main_program` .

    Returns:
        None

    Raises:
        TypeError: If ``clip`` is not an instance of a ``GradientClipBase``
            derived class, or if ``param_list`` contains elements that are
            neither ``Parameter`` objects nor parameter names.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            def network():
                image = fluid.data(name='image', shape=[
                                   None, 28], dtype='float32')
                param_attr1 = fluid.ParamAttr("fc1_param")
                fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
                param_attr2 = fluid.ParamAttr("fc2_param")
                fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
                loss = fluid.layers.reduce_mean(fc2)
                return loss


            # network 1: clip all parameter gradient
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 2: clip parameter gradient by name
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=["fc1_param", "fc2_param"])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 3: clip parameter gradient by value
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                param_var1 = fluid.default_main_program().global_block().var("fc1_param")
                param_var2 = fluid.default_main_program().global_block().var("fc2_param")
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=[param_var1, param_var2])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0)
                clip2 = fluid.clip.GradientClipByNorm(clip_norm=1.0)
                # Set the gradient clipping strategy: clip1
                fluid.clip.set_gradient_clip(clip1)
                # Set the gradient clipping strategy: clip2
                sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
                sgd.minimize(loss)
                # 'set_gradient_clip' will not take effect when setting has a conflict,
                # and the gradient clipping strategy will be 'clip2'
    """
    warnings.warn("Caution! 'set_gradient_clip' is not recommended "
                  "and may be deprecated in future! "
                  "We recommend a new strategy: set 'grad_clip' "
                  "when initializing the 'optimizer'. "
                  "This method can reduce the mistakes, please "
                  "refer to documention of 'optimizer'.")

    if not isinstance(clip, ClipGradBase):
        raise TypeError(
            "'clip' should be an instance of ClipGradBase's derived class")
    if program is None:
        program = framework.default_main_program()

    # Ops created by ``minimize`` live in an "optimizer" name scope; if any
    # exist already, the clip attribute set below will never be consumed.
    for op in program.block(0).ops:
        if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
                "op_namescope"):
            warnings.warn(
                "'minimize' has been invoked before, this will make 'set_gradient_clip' "
                "be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
            )
            break

    if param_list is None:
        param_list = program.block(0).all_parameters()
    # Materialize once: ``param_list`` is traversed several times below, and a
    # one-shot iterable (e.g. a generator) would otherwise be partially
    # consumed by the first ``all(...)`` check.
    param_list = list(param_list)
    if all(isinstance(elem, six.string_types) for elem in param_list):
        # Names were given: resolve them to variables of block 0.
        param_list = [program.block(0).var(elem) for elem in param_list]
    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    for param in param_list:
        # Each parameter gets an independent copy, so per-parameter state kept
        # on the clip object is never shared by accident.
        param.gradient_clip_attr = copy.deepcopy(clip)
F
fengjiayi 已提交
821 822


823
def append_gradient_clip_ops(param_grads):
    """
    Append the gradient clipping ops requested through ``set_gradient_clip``.

    A parameter opts into clipping by carrying a ``gradient_clip_attr``
    (an instance of a ``ClipGradBase`` derived class). Clipping runs in two
    passes:

    1. ``_process_context`` is called for every selected pair, so strategies
       that need information from all gradients (e.g. a global norm) can
       collect it into a shared context.
    2. ``_create_operators`` emits the clipping ops for each pair.

    Args:
        param_grads (list): ``(param, grad)`` pairs; ``grad`` may be None.

    Returns:
        list: If at least one pair has a gradient but no such parameter
        carries a clip attribute, ``param_grads`` itself is returned
        unchanged. Otherwise a new list of ``[param, grad]`` pairs, in input
        order, is returned:

        * gradients of parameters with a clip attribute are replaced by their
          clipped version;
        * gradients of parameters without one are passed through untouched;
        * pairs whose gradient is None are dropped.

    Raises:
        TypeError: If a parameter's ``gradient_clip_attr`` is not an instance
            of ``ClipGradBase``.
    """
    context = dict()
    has_grad = False
    has_clip = False
    for p, g in param_grads:
        if g is None:
            continue
        has_grad = True
        clip_attr = getattr(p, 'gradient_clip_attr', None)
        if clip_attr is None:
            # Not selected for clipping (e.g. excluded from ``param_list``).
            # Skip it instead of aborting, so the selected parameters are
            # still clipped.
            continue
        if not isinstance(clip_attr, ClipGradBase):
            raise TypeError(
                "clip attribute should be an instance of GradientClipBase")
        has_clip = True
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('gradient_clip'):
            clip_attr._process_context(context=context, param=p, grad=g)

    if has_grad and not has_clip:
        # Nothing asked for clipping: hand the input back untouched
        # (None gradients included).
        return param_grads

    res = []
    clipped = []
    param_new_grad_name_dict = dict()
    for p, g in param_grads:
        if g is None:
            continue
        # Look up this parameter's own clip object. Reusing the variable left
        # over from the first pass would apply the last parameter's strategy
        # to every gradient.
        clip_attr = getattr(p, 'gradient_clip_attr', None)
        if clip_attr is None:
            res.append([p, g])
            continue
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('gradient_clip'):
            param, new_grad = clip_attr._create_operators(param=p, grad=g)
        param_new_grad_name_dict[param.name] = new_grad.name
        pair = [param, new_grad]
        res.append(pair)
        clipped.append(pair)

    if param_new_grad_name_dict:
        # Only clipped gradients get new names that need to be patched.
        _correct_clip_op_role_var(clipped, param_new_grad_name_dict)
    return res


# change wrong mapping relation between param & grad in clip op
855 856
# Note: This function is sensitive to the time cost of the network with gradient clipping 
# and should not be changed easily. If you must change, please test the time cost.
857 858 859 860
def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
    block_id_list = []
    if len(param_new_grad_name_dict) == 0:
        return
861 862
    for param, grad in params_grads:
        if grad is None:
863
            continue
864 865 866 867
        block_id = param.block.idx
        if block_id in block_id_list:
            continue
        block_id_list.append(block_id)
868
        for op in param.block.program.global_block().ops:
W
WangXi 已提交
869
            if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
870 871 872 873 874 875
                    "op_namescope") and op.attr('op_role_var'):
                param_name = op.attr('op_role_var')[0]
                if param_name in param_new_grad_name_dict:
                    correct_p_g = [
                        param_name, param_new_grad_name_dict[param_name]
                    ]
C
Chengmo 已提交
876
                    op._set_attr('op_role_var', correct_p_g)
Y
Yu Yang 已提交
877 878


879 880 881 882
# The ``GradientClip*`` names (used throughout this module's docstrings and
# examples) refer to exactly the same classes as the ``ClipGrad*`` names.
(GradientClipBase, GradientClipByValue, GradientClipByNorm,
 GradientClipByGlobalNorm) = (ClipGradBase, ClipGradByValue, ClipGradByNorm,
                              ClipGradByGlobalNorm)