clip.py 33.3 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
F
fengjiayi 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
F
fengjiayi 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
F
update  
fengjiayi 已提交
14

15 16
from __future__ import print_function

F
fengjiayi 已提交
17
import copy
18
import six
19
import warnings
F
fengjiayi 已提交
20

Y
Yu Yang 已提交
21
import functools
W
WangXi 已提交
22
import paddle
23 24
from . import layers
from . import framework
F
fengjiayi 已提交
25
from . import core
C
Chengmo 已提交
26
from . import name_scope
27
from .dygraph import base as imperative_base
W
WangXi 已提交
28 29 30
from .data_feeder import check_variable_and_dtype
from .framework import in_dygraph_mode
from .layer_helper import LayerHelper
31
from .framework import default_main_program
32
from paddle import _C_ops
Y
Yu Yang 已提交
33

F
fengjiayi 已提交
34
__all__ = [
35 36
    'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
    'ClipGradByNorm', 'ClipGradByGlobalNorm'
F
fengjiayi 已提交
37
]
Y
Yu Yang 已提交
38 39


W
WangXi 已提交
40 41 42 43 44
def _squared_l2_norm(x):
    r"""
    Return the squared L2 norm (sum of squared elements) of ``x``.

    Three execution paths are used:

    - On XPU builds, or when ``x`` is FP16, the result is composed from
      ``layers.square`` followed by ``layers.reduce_sum``.
    - In dygraph mode, the ``squared_l2_norm`` C++ op is invoked directly.
    - Otherwise (static graph) a ``squared_l2_norm`` op is appended through a
      ``LayerHelper``; ``x`` is checked to be float32 or float64 first.

    Args:
        x (Variable|Tensor): The input tensor.

    Returns:
        Variable|Tensor: The squared L2 norm of ``x``.
    """
    if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
        # Composite fallback: element-wise square, then reduce over all dims.
        return layers.reduce_sum(layers.square(x))

    if in_dygraph_mode():
        return _C_ops.squared_l2_norm(x)

    op_type = 'squared_l2_norm'
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
    # Same keyword arguments the helper would get from ``**locals()`` here.
    helper = LayerHelper(op_type, x=x, op_type=op_type)
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(type=op_type, inputs={"X": x}, outputs={'Out': out})
    return out


F
fengjiayi 已提交
64
class BaseErrorClipAttr(object):
    """
    Interface for error-clip attributes that can be attached to a variable.

    Concrete subclasses must override both ``__str__`` and
    ``_append_clip_op``; the base versions only raise ``NotImplementedError``.
    """

    def __str__(self):
        raise NotImplementedError

    def _append_clip_op(self, block, grad_name):
        """
        Append to ``block`` the op(s) that clip the gradient ``grad_name``.

        Must be provided by subclasses.
        """
        raise NotImplementedError


class ErrorClipByValue(BaseErrorClipAttr):
    r"""
    Clips tensor values to the range [min, max].

    Given a tensor ``t`` (see Examples below), this operation clips its value
    to ``min`` and ``max`` inplace.

    - Any values less than min are set to min.
    - Any values greater than max are set to max.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by user,
            it will be set to ``-max`` by framework.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            BATCH_SIZE = 128
            CLIP_MAX = 2e-6
            CLIP_MIN = -1e-6
            prog = fluid.framework.Program()
            with fluid.program_guard(main_program=prog):
                image = fluid.layers.data(
                    name='x', shape=[784], dtype='float32')
                hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
                hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
                predict = fluid.layers.fc(
                    input=hidden2, size=10, act='softmax')
                label = fluid.layers.data(name='y', shape=[1], dtype='int64')
                cost = fluid.layers.cross_entropy(input=predict, label=label)
                avg_cost = fluid.layers.mean(cost)
            prog_clip = prog.clone()
            prog_clip.block(0).var(hidden1.name)._set_error_clip(
                fluid.clip.ErrorClipByValue(
                    max=CLIP_MAX, min=CLIP_MIN))
    """

    def __init__(self, max, min=None):
        # Normalise both bounds to Python floats. When ``min`` is omitted the
        # range becomes symmetric: [-max, max].
        max = float(max)
        if min is None:
            min = -max
        else:
            min = float(min)
        self.max = max
        self.min = min

    def __str__(self):
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def _append_clip_op(self, block, grad_name):
        """
        Append a ``clip`` op for ``grad_name`` directly to ``block.desc``.

        The op reads and writes the same variable name, so the gradient is
        overwritten with its clipped value.
        """
        clip_op_desc = block.desc.append_op()
        clip_op_desc.set_type("clip")
        clip_op_desc.set_input("X", [grad_name])
        clip_op_desc.set_output("Out", [grad_name])
        clip_op_desc._set_attr("min", self.min)
        clip_op_desc._set_attr("max", self.max)
F
fengjiayi 已提交
130 131 132 133 134 135


def error_clip_callback(block, context):
    """
    Apply error clipping to the gradients output by the last op of ``block``.

    ``context`` is a grad-to-var map: gradient variable name -> forward
    variable name. For every output of the most recently appended op that
    appears in this map, the forward variable's ``error_clip`` attribute (if
    it is set) is asked to append a clip op for that gradient.

    Args:
        block: The block whose last op is inspected and extended.
        context (dict): Mapping from gradient name to forward variable name.

    Raises:
        TypeError: If a forward variable's ``error_clip`` is neither ``None``
            nor an instance of ``BaseErrorClipAttr``.
    """
    grad_to_var = context
    last_op = block.desc.op(block.desc.op_size() - 1)
    # Snapshot the matching outputs before any clip op is appended.
    grads_to_check = [
        name for name in last_op.output_arg_names() if name in grad_to_var
    ]
    for grad_name in grads_to_check:
        fwd_var = block._var_recursive(grad_to_var[grad_name])
        error_clip = getattr(fwd_var, "error_clip", None)
        if error_clip is None:
            continue
        if not isinstance(error_clip, BaseErrorClipAttr):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
        error_clip._append_clip_op(block, grad_name)
F
fengjiayi 已提交
146 147


148 149 150
class ClipGradBase(object):
    """
    Base class of the gradient-clipping strategies.

    Calling an instance with a list of ``(param, grad)`` pairs dispatches to
    ``_dygraph_clip`` in dygraph mode and to ``_static_clip`` otherwise.
    Subclasses implement those hooks as well as the older
    ``_process_context`` / ``_create_operators`` pair.
    """

    def __init__(self):
        super(ClipGradBase, self).__init__()

    def __str__(self):
        raise NotImplementedError()

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        raise NotImplementedError

    def _static_clip(self, params_grads):
        raise NotImplementedError

    def __call__(self, params_grads):
        if framework.in_dygraph_mode():
            return self._dygraph_clip(params_grads)

        # Static graph: emit a single warning if any parameter carries a
        # ``gradient_clip_attr`` (the warning text attributes this to
        # ``set_gradient_clip``). ``any`` stops at the first match.
        found_clip_attr = any(
            getattr(param, 'gradient_clip_attr', None) is not None
            for param, _ in params_grads)
        if found_clip_attr:
            warnings.warn(
                "'set_gradient_clip' will be ineffective, because you have "
                "set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' "
                "is redundant and you can remove it.")
        return self._static_clip(params_grads)

    def _process_context(self, context, param, grad):
        raise NotImplementedError()

    def _create_operators(self, param, grad):
        raise NotImplementedError()
Y
Yu Yang 已提交
180 181


182
class ClipGradByValue(ClipGradBase):
    """
    Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].

    - Any values less than min are set to ``min``.

    - Any values greater than max are set to ``max``.

    The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    Note:
        ``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by user, it will be set to ``-max``
            automatically. In this case, ``max`` must be greater than 0.

    Raises:
        ValueError: If ``min`` is not set and ``max`` is not greater than 0.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByValue(min=-1, max=1)
            sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sdg.step()
    """

    def __init__(self, max, min=None):
        super(ClipGradByValue, self).__init__()
        # Convert first so the positivity check compares a real float, the
        # same value that is stored on ``self.max``.
        max = float(max)
        if min is None:
            # A symmetric range [-max, max] requires max > 0. Use an explicit
            # raise instead of ``assert`` so the check survives ``python -O``.
            if not max > 0.0:
                raise ValueError(
                    "When 'min' is not set, 'max' must be greater than 0, "
                    "but received max=%f." % max)
            min = -max
        self.max = max
        self.min = float(min)

    def __str__(self):
        return "Clip Gradient By Value, min = %f, max=%f" % (self.min, self.max)

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """
        Clip every gradient element-wise in dygraph mode.

        Pairs with a ``None`` gradient are dropped. Pairs whose parameter has
        ``need_clip`` set to ``False`` are passed through unchanged.
        """
        params_and_grads = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            new_grad = layers.clip(x=g, min=self.min, max=self.max)
            params_and_grads.append((p, new_grad))
        return params_and_grads

    def _static_clip(self, params_grads):
        """
        Append element-wise ``clip`` ops for every gradient (static graph).

        Filtering follows ``_dygraph_clip``. The names of the new gradients
        are collected and passed to ``_correct_clip_op_role_var``.
        """
        params_and_grads = []
        param_new_grad_name_dict = dict()
        with framework.name_scope('gradient_clip'):
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    new_grad = layers.clip(x=g, min=self.min, max=self.max)
                params_and_grads.append((p, new_grad))
                param_new_grad_name_dict[p.name] = new_grad.name
        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        # Element-wise clipping needs no shared context between parameters.
        pass

    def _create_operators(self, param, grad):
        new_grad = layers.clip(x=grad, min=self.min, max=self.max)
        return param, new_grad


273
class ClipGradByNorm(ClipGradBase):
    r"""
    Limit the l2 norm of multi-dimensional Tensor :math:`X` to ``clip_norm`` .

    - If the l2 norm of :math:`X` is greater than ``clip_norm`` , :math:`X` will be compressed by a ratio.

    - If the l2 norm of :math:`X` is less than or equal to ``clip_norm`` , nothing will be done.

    The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::
        Out =
        \left\{
            \begin{array}{ccl}
                X & & if (norm(X) \leq clip\_norm) \\
                \frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\
        \end{array}
        \right.


    where :math:`norm(X)` represents the L2 norm of :math:`X`.

    .. math::
        norm(X) = ( \sum_{i=1}^{n}|x\_i|^2)^{ \frac{1}{2}}

    Note:
        ``need_clip`` of ``ClipGradByNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        clip_norm(float): The maximum norm value.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
            sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sdg.step()
    """

    def __init__(self, clip_norm):
        super(ClipGradByNorm, self).__init__()
        self.clip_norm = float(clip_norm)

    def __str__(self):
        return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """
        Clip each gradient's own L2 norm to ``clip_norm`` in dygraph mode.

        ``None`` gradients are dropped; parameters with ``need_clip`` equal
        to ``False`` keep their gradient untouched.
        """
        clipped_pairs = []
        for param, grad in params_grads:
            if grad is None:
                continue
            if getattr(param, 'need_clip', True) is not False:
                grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
            clipped_pairs.append((param, grad))
        return clipped_pairs

    def _static_clip(self, params_grads):
        """
        Append a ``clip_by_norm`` op per gradient (static graph).

        Uses the same filtering as ``_dygraph_clip`` and reports the renamed
        gradients to ``_correct_clip_op_role_var``.
        """
        clipped_pairs = []
        renamed_grads = {}
        with framework.name_scope('gradient_clip'):
            for param, grad in params_grads:
                if grad is None:
                    continue
                if getattr(param, 'need_clip', True) is False:
                    clipped_pairs.append((param, grad))
                    continue

                with param.block.program._optimized_guard([param, grad]):
                    clipped_grad = layers.clip_by_norm(
                        x=grad, max_norm=self.clip_norm)
                renamed_grads[param.name] = clipped_grad.name
                clipped_pairs.append((param, clipped_grad))
        _correct_clip_op_role_var(clipped_pairs, renamed_grads)
        return clipped_pairs

    def _process_context(self, context, param, grad):
        # Per-tensor norm clipping needs no shared context.
        pass

    def _create_operators(self, param, grad):
        return param, layers.clip_by_norm(x=grad, max_norm=self.clip_norm)


375 376 377 378 379 380 381 382 383 384 385 386 387 388
_allow_pure_fp16_global_norm_clip_flag = False


def _allow_pure_fp16_global_norm_clip(*args):
    global _allow_pure_fp16_global_norm_clip_flag
    if len(args) == 0:
        return _allow_pure_fp16_global_norm_clip_flag
    else:
        assert len(args) == 1 and isinstance(args[0], bool)
        old_value = _allow_pure_fp16_global_norm_clip_flag
        _allow_pure_fp16_global_norm_clip_flag = args[0]
        return old_value


389
class ClipGradByGlobalNorm(ClipGradBase):
    r"""
    Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
    :math:`t\_list` , and limit it to ``clip_norm`` .

    - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.

    - If the global norm is less than or equal to ``clip_norm`` , nothing will be done.

    The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::

        t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}

    where:

    .. math::

        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    Note:
        ``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        clip_norm (float): The maximum norm value.
        group_name (str, optional): The group name for this clip. Default value is ``default_group``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sdg.step()
    """

    def __init__(self, clip_norm, group_name="default_group"):
        super(ClipGradByGlobalNorm, self).__init__()
        self.clip_norm = float(clip_norm)
        self.group_name = group_name

    def __str__(self):
        # The stored value is the clipping threshold, not a computed global
        # norm, so label it ``clip_norm``.
        return "Gradient Clip By GlobalNorm, clip_norm=%f" % (self.clip_norm)

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """
        Scale gradients by ``clip_norm / global_norm`` in dygraph mode.

        Scaling only happens when the global norm exceeds ``clip_norm``.
        ``None`` gradients are dropped from the result. Parameters with
        ``need_clip`` equal to ``False`` are excluded from the norm and keep
        their gradient. If no gradient takes part, ``params_grads`` is
        returned unchanged.
        """
        params_and_grads = []
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                # Sparse gradient: merge rows and take the dense tensor so
                # its norm can be computed.
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

            sum_square = _squared_l2_norm(merge_grad)
            if sum_square.dtype == core.VarDesc.VarType.FP16:
                sum_square_list_fp16.append(sum_square)
            elif sum_square.dtype == core.VarDesc.VarType.FP32:
                sum_square_list_fp32.append(sum_square)
            else:
                sum_square_list.append(sum_square)

        # all parameters have been filtered out
        if len(sum_square_list) + len(sum_square_list_fp16) + len(
                sum_square_list_fp32) == 0:
            return params_grads

        # Accumulate in float64 when any squared norm is neither FP16 nor
        # FP32 (e.g. FP64); otherwise accumulate in float32.
        sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
        global_norm_var = []
        if len(sum_square_list_fp16) > 0:
            global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
            global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
        if len(sum_square_list_fp32) > 0:
            global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
            if sum_dtype == 'float32':
                global_norm_var.append(global_norm_var_fp32)
            else:
                global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
        if len(sum_square_list) > 0:
            global_norm_var_fp64 = paddle.add_n(sum_square_list)
            global_norm_var.append(global_norm_var_fp64)
        global_norm_var = paddle.add_n(global_norm_var)
        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)

        # only when global_norm_var > max_global_norm, grad need clip
        need_clip = False
        if global_norm_var > max_global_norm:
            need_clip = True

        if need_clip:
            clip_var = layers.elementwise_div(
                x=max_global_norm, y=global_norm_var)
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            # TODO(wangxi): use inplace elementwise_mul
            if need_clip:
                clip_input = (clip_var.astype('float16')
                              if g.dtype == core.VarDesc.VarType.FP16 else
                              clip_var)
                new_grad = layers.elementwise_mul(x=g, y=clip_input)
                params_and_grads.append((p, new_grad))
            else:
                params_and_grads.append((p, g))

        return params_and_grads

    def _static_clip(self, params_grads):
        """
        Scale gradients in place by the global-norm ratio (static graph).

        Builds ops computing ``scale = clip_norm / max(global_norm, clip_norm)``
        and appends an in-place ``elementwise_mul`` for every gradient whose
        parameter does not set ``need_clip=False``. ``None`` gradients are
        dropped from the result. If no gradient takes part, ``params_grads``
        is returned unchanged.
        """
        params_and_grads = []
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        # (param, grad) pair used to open the optimizer-role guard around the
        # global-norm ops. Track the last pair that has a real gradient
        # instead of reusing the loop variables after the loop: when
        # ``params_grads`` ends with a ``(param, None)`` pair those variables
        # would hand ``None`` to ``_optimized_guard``.
        guard_pair = None
        with framework.name_scope('gradient_clip'):
            for p, g in params_grads:
                if g is None:
                    continue
                guard_pair = (p, g)
                if getattr(p, 'need_clip', True) is False:
                    continue
                merge_grad = g
                with p.block.program._optimized_guard([p, g]):
                    if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                        merge_grad = layers.merge_selected_rows(g)
                        merge_grad = layers.get_tensor_from_selected_rows(
                            merge_grad)
                    sum_square = _squared_l2_norm(merge_grad)
                    if sum_square.dtype == core.VarDesc.VarType.FP16:
                        sum_square_list_fp16.append(sum_square)
                    elif sum_square.dtype == core.VarDesc.VarType.FP32:
                        sum_square_list_fp32.append(sum_square)
                    else:
                        sum_square_list.append(sum_square)

            # all parameters have been filtered out
            if len(sum_square_list) + len(sum_square_list_fp16) + len(
                    sum_square_list_fp32) == 0:
                return params_grads

            # At least one squared norm exists, so ``guard_pair`` is set.
            guard_param, guard_grad = guard_pair
            with guard_param.block.program._optimized_guard(
                [guard_param, guard_grad]):
                sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"

                global_norm_var = []
                if len(sum_square_list_fp16) > 0:
                    global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
                    # Keep the FP16 sum un-cast only when it is the sole
                    # contribution and pure-FP16 clipping is allowed.
                    if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip(
                    ):
                        global_norm_var.append(
                            global_norm_var_fp16.astype(sum_dtype))
                    else:
                        global_norm_var.append(global_norm_var_fp16)
                if len(sum_square_list_fp32) > 0:
                    global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
                    if sum_dtype == 'float32':
                        global_norm_var.append(global_norm_var_fp32)
                    else:
                        global_norm_var.append(
                            global_norm_var_fp32.astype(sum_dtype))
                if len(sum_square_list) > 0:
                    # fp64
                    global_norm_var_other_dtype = layers.sums(sum_square_list)
                    global_norm_var.append(global_norm_var_other_dtype)

                global_norm_var = layers.sums(global_norm_var) if len(
                    global_norm_var) > 1 else global_norm_var[0]
                global_norm_var = layers.sqrt(x=global_norm_var)
                max_global_norm = layers.fill_constant(
                    shape=[1],
                    dtype=global_norm_var.dtype,
                    value=self.clip_norm)
                scale_var = layers.elementwise_div(
                    x=max_global_norm,
                    y=layers.elementwise_max(
                        x=max_global_norm, y=global_norm_var))
            param_new_grad_name_dict = dict()
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    # inplace
                    scale_input = (scale_var.astype('float16')
                                   if g.dtype == core.VarDesc.VarType.FP16 and
                                   scale_var.dtype != core.VarDesc.VarType.FP16
                                   else scale_var)
                    # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
                    # will be in different blocks with the gradient clip related ops.
                    # We need to handle the correct block, otherwise will encounter
                    # a 'NotFoundError' during compile time.
                    block = default_main_program().current_block()
                    block.append_op(
                        type='elementwise_mul',
                        inputs={'X': g,
                                'Y': scale_input},
                        outputs={'Out': g})

                param_new_grad_name_dict[p.name] = g.name
                params_and_grads.append((p, g))

        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        """
        Per-parameter hook: record ``grad``'s squared L2 norm in ``context``.

        The first call for ``self.group_name`` creates the group's list, its
        clip value and a constant tensor holding ``clip_norm``. Later calls
        require the same ``clip_norm``.

        Raises:
            ValueError: If the group was registered with a different
                ``clip_norm``.
        """
        if self.group_name not in context:
            context[self.group_name] = []
            context[self.group_name + "_clip_value"] = self.clip_norm
            context[self.group_name + "_clip"] = layers.fill_constant(
                shape=[1], dtype=grad.dtype, value=self.clip_norm)
        elif self.clip_norm != context[self.group_name + "_clip_value"]:
            raise ValueError(
                "All parameters' 'clip_norm' of a same group should be the same"
            )

        merge_grad = grad
        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
            merge_grad = layers.merge_selected_rows(grad)
            merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

        local_norm_var = _squared_l2_norm(merge_grad)
        context[self.group_name].append(local_norm_var)

        self.context = context

    def _create_operators(self, param, grad):
        """
        Scale ``grad`` in place by the group's clip ratio.

        The ratio ``clip / max(clip, sqrt(sum(norms)))`` is built once per
        group and cached in ``self.context`` under ``<group_name>_scale``.
        """
        group_scale_name = self.group_name + "_scale"
        if group_scale_name not in self.context:
            group_norm_var = layers.sums(input=self.context[self.group_name])
            group_norm_var = layers.sqrt(x=group_norm_var)
            clip_var = self.context[self.group_name + "_clip"]
            group_scale_var = layers.elementwise_div(
                x=clip_var,
                y=layers.elementwise_max(
                    x=clip_var, y=group_norm_var))
            assert group_scale_var.shape == (1, )
            self.context[group_scale_name] = group_scale_var

        # inplace
        param.block.append_op(
            type='elementwise_mul',
            inputs={'X': grad,
                    'Y': self.context[group_scale_name]},
            outputs={'Out': grad})

        return param, grad
F
fengjiayi 已提交
663 664


665
@framework.dygraph_not_support
def set_gradient_clip(clip, param_list=None, program=None):
    """
    :api_attr: Static Graph

    Warning:

        This API must be used after building the network and before ``minimize``.
        It may be removed in future releases, so it is not recommended.
        Prefer setting ``grad_clip`` when initializing the ``optimizer``,
        which is a better way to clip gradients. There are three clipping strategies:
        :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
        :ref:`api_fluid_clip_GradientClipByValue` .

    Specify the parameters whose gradients should be clipped.

    A deep copy of ``clip`` is attached to each selected parameter as its
    ``gradient_clip_attr``.

    Args:
        clip (GradientClipBase): Gradient clipping strategy. It must be an instance
            of a class derived from ``GradientClipBase``, for example
            :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm`
            or :ref:`api_fluid_clip_GradientClipByValue` .
        param_list (list(Variable|str), optional): Parameters that require gradient clip.
                It can be a list of parameters or a list of parameter names.
                Default None, meaning that all parameters in the program will be included.
        program (Program, optional): The program where parameters are located.
                Default None, meaning that using :ref:`api_fluid_default_main_program` .

    Returns:
        None

    Raises:
        TypeError: If ``clip`` is not an instance of a ``ClipGradBase`` subclass,
            or if ``param_list`` does not consist entirely of parameters or
            entirely of parameter names.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            def network():
                image = fluid.data(name='image', shape=[
                                   None, 28], dtype='float32')
                param_attr1 = fluid.ParamAttr("fc1_param")
                fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
                param_attr2 = fluid.ParamAttr("fc2_param")
                fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
                loss = fluid.layers.reduce_mean(fc2)
                return loss


            # network 1: clip all parameter gradient
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 2: clip parameter gradient by name
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=["fc1_param", "fc2_param"])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 3: clip parameter gradient by parameter variable
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                param_var1 = fluid.default_main_program().global_block().var("fc1_param")
                param_var2 = fluid.default_main_program().global_block().var("fc2_param")
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=[param_var1, param_var2])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0)
                clip2 = fluid.clip.GradientClipByNorm(clip_norm=1.0)
                # Set the gradient clipping strategy: clip1
                fluid.clip.set_gradient_clip(clip1)
                # Set the gradient clipping strategy: clip2
                sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
                sgd.minimize(loss)
                # 'set_gradient_clip' will not take effect when setting has a conflict,
                # and the gradient clipping strategy will be 'clip2'
    """
    warnings.warn("Caution! 'set_gradient_clip' is not recommended "
                  "and may be deprecated in future! "
                  "We recommend a new strategy: set 'grad_clip' "
                  "when initializing the 'optimizer'. "
                  "This method can reduce the mistakes, please "
                  "refer to documention of 'optimizer'.")

    if not isinstance(clip, ClipGradBase):
        raise TypeError(
            "'clip' should be an instance of ClipGradBase's derived class")
    if program is None:
        program = framework.default_main_program()

    # Ops created under the "optimizer" name scope mean minimize() already ran,
    # so the clip attributes set below would never be consulted.
    for op in program.block(0).ops:
        if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
                "op_namescope"):
            warnings.warn(
                "'minimize' has been invoked before, this will make 'set_gradient_clip' "
                "be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
            )
            break

    if param_list is None:
        param_list = program.block(0).all_parameters()
    else:
        # Materialize once: param_list is scanned by two all() checks and then
        # iterated again, which would silently drop elements of a one-shot
        # iterator such as a generator.
        param_list = list(param_list)
    if all(isinstance(elem, six.string_types) for elem in param_list):
        param_list = [program.block(0).var(elem) for elem in param_list]
    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    # Each parameter gets its own copy so strategies never share mutable state.
    for param in param_list:
        param.gradient_clip_attr = copy.deepcopy(clip)
F
fengjiayi 已提交
788 789


790
def append_gradient_clip_ops(param_grads):
    """Append gradient-clipping ops for parameters configured by ``set_gradient_clip``.

    Each parameter's strategy is read from its ``gradient_clip_attr``
    attribute. Clipping runs in two passes over the clipped pairs:

    1. ``_process_context`` is called with a ``context`` dict shared by all
       pairs. Global-norm clipping uses it to collect per-group norms.
    2. ``_create_operators`` emits the ops producing each clipped gradient.

    Parameters without a strategy (e.g. ones left out of ``param_list``)
    keep their gradient unchanged, rather than disabling clipping for every
    parameter.

    Args:
        param_grads (list): ``(param, grad)`` pairs. Pairs whose ``grad`` is
            None are skipped.

    Returns:
        list: ``param_grads`` itself if no parameter with a gradient has a
        clip strategy. Otherwise a new list of ``[param, grad]`` pairs, with
        gradients clipped where a strategy is set and ``None``-gradient pairs
        omitted.

    Raises:
        TypeError: If a ``gradient_clip_attr`` is set but is not an instance
            of ``ClipGradBase``.
    """
    # Resolve each parameter's own strategy up front. The second pass must
    # use the strategy attached to *that* parameter; reusing a single loop
    # variable would apply the last-seen strategy to every gradient.
    clip_items = []
    for p, g in param_grads:
        if g is None:
            continue
        clip_attr = getattr(p, 'gradient_clip_attr', None)
        if clip_attr is not None and not isinstance(clip_attr, ClipGradBase):
            raise TypeError(
                "clip attribute should be an instance of GradientClipBase")
        clip_items.append((p, g, clip_attr))

    # No parameter requested clipping: hand the input back untouched.
    if clip_items and all(attr is None for _, _, attr in clip_items):
        return param_grads

    context = dict()
    for p, g, clip_attr in clip_items:
        if clip_attr is None:
            continue
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('gradient_clip'):
            clip_attr._process_context(context=context, param=p, grad=g)

    res = []
    param_new_grad_name_dict = dict()
    for p, g, clip_attr in clip_items:
        if clip_attr is None:
            # Not selected for clipping: keep the original gradient.
            res.append([p, g])
            continue
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('gradient_clip'):
            param, new_grad = clip_attr._create_operators(param=p, grad=g)
            param_new_grad_name_dict[param.name] = new_grad.name
            res.append([param, new_grad])

    _correct_clip_op_role_var(res, param_new_grad_name_dict)
    return res


# Rewrite the 'op_role_var' attribute of gradient-clip ops so that it pairs each
# parameter with its clipped gradient rather than the original one.
822 823
# Note: the cost of this function adds directly to the time cost of every network
# that uses gradient clipping, so do not change it casually; if you must, measure
# the time cost before and after.
824 825 826 827
def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
    block_id_list = []
    if len(param_new_grad_name_dict) == 0:
        return
828 829
    for param, grad in params_grads:
        if grad is None:
830
            continue
831 832 833 834
        block_id = param.block.idx
        if block_id in block_id_list:
            continue
        block_id_list.append(block_id)
835
        for op in param.block.program.global_block().ops:
W
WangXi 已提交
836
            if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
837 838 839 840 841 842
                    "op_namescope") and op.attr('op_role_var'):
                param_name = op.attr('op_role_var')[0]
                if param_name in param_new_grad_name_dict:
                    correct_p_g = [
                        param_name, param_new_grad_name_dict[param_name]
                    ]
C
Chengmo 已提交
843
                    op._set_attr('op_role_var', correct_p_g)
Y
Yu Yang 已提交
844 845


846 847 848 849
# The ``GradientClip*`` names are aliases bound to the ``ClipGrad*`` classes;
# the docstring examples in this module refer to them, e.g.
# ``fluid.clip.GradientClipByGlobalNorm``.
(GradientClipBase, GradientClipByValue, GradientClipByNorm,
 GradientClipByGlobalNorm) = (ClipGradBase, ClipGradByValue, ClipGradByNorm,
                              ClipGradByGlobalNorm)