#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import copy
import six
import warnings

import functools
import paddle
from . import layers
from . import framework
from . import core
from . import name_scope
from .dygraph import base as imperative_base
from .data_feeder import check_variable_and_dtype
from .framework import in_dygraph_mode
from .layer_helper import LayerHelper
from .framework import default_main_program

# Public API of this module.
__all__ = [
    'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
    'ClipGradByNorm', 'ClipGradByGlobalNorm'
]


def _squared_l2_norm(x):
    r"""
    This OP returns the squared L2 norm of a tensor, i.e. ``sum(x * x)``
    over all elements of ``x``.

    Args:
        x (Variable|Tensor): The input tensor.

    Returns:
        Variable|Tensor: A tensor holding the sum of squares of ``x``.
    """

    # XPU builds and FP16 inputs use the composite square + reduce_sum path
    # instead of the fused ``squared_l2_norm`` op.
    if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
        square = layers.square(x)
        sum_square = layers.reduce_sum(square)
        return sum_square

    if in_dygraph_mode():
        return core.ops.squared_l2_norm(x)

    op_type = 'squared_l2_norm'
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
    # Same kwargs the previous ``**locals()`` call forwarded, made explicit so
    # adding a local variable above cannot silently change them.
    helper = LayerHelper(op_type, x=x, op_type=op_type)
    out = helper.create_variable_for_type_inference(x.dtype)

    inputs = {"X": x}
    outputs = {'Out': out}
    helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
    return out


class BaseErrorClipAttr(object):
    """Abstract base class for error-clip attributes attached to variables.

    Subclasses must implement ``__str__`` and ``_append_clip_op``.
    ``error_clip_callback`` accepts either ``None`` or an instance of this
    class as a variable's ``error_clip`` attribute.
    """

    def __str__(self):
        raise NotImplementedError()

    def _append_clip_op(self, block, grad_name):
        """Append the op(s) that clip the gradient named ``grad_name`` in ``block``."""
        raise NotImplementedError()


class ErrorClipByValue(BaseErrorClipAttr):
    r"""
    Clips the gradient (error) of a variable to the range [min, max].

    The clip is applied inplace on the gradient: the appended ``clip`` op
    reads and writes the same gradient variable. Attach it to a forward
    variable with ``Variable._set_error_clip`` (see Examples below).

    - Any values less than min are set to min.
    - Any values greater than max are set to max.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by user, \
        it will be set to ``-max`` by framework.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            BATCH_SIZE = 128
            CLIP_MAX = 2e-6
            CLIP_MIN = -1e-6
            prog = fluid.framework.Program()
            with fluid.program_guard(main_program=prog):
                image = fluid.layers.data(
                    name='x', shape=[784], dtype='float32')
                hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
                hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
                predict = fluid.layers.fc(
                    input=hidden2, size=10, act='softmax')
                label = fluid.layers.data(name='y', shape=[1], dtype='int64')
                cost = fluid.layers.cross_entropy(input=predict, label=label)
                avg_cost = fluid.layers.mean(cost)
            prog_clip = prog.clone()
            prog_clip.block(0).var(hidden1.name)._set_error_clip(
                fluid.clip.ErrorClipByValue(
                    max=CLIP_MAX, min=CLIP_MIN))
    """

    def __init__(self, max, min=None):
        max = float(max)
        if min is None:
            min = -max
        else:
            min = float(min)
        self.max = max
        self.min = min

    def __str__(self):
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def _append_clip_op(self, block, grad_name):
        """Append a ``clip`` op to ``block`` that clips ``grad_name`` inplace."""
        clip_op_desc = block.desc.append_op()
        clip_op_desc.set_type("clip")
        # Input and output are the same variable, so the gradient is clipped inplace.
        clip_op_desc.set_input("X", [grad_name])
        clip_op_desc.set_output("Out", [grad_name])
        clip_op_desc._set_attr("min", self.min)
        clip_op_desc._set_attr("max", self.max)


def error_clip_callback(block, context):
    """Append error-clip ops for gradients produced by the last op of ``block``.

    Inspects the most recently appended op in ``block``. For every output of
    that op that appears in ``context``, the corresponding forward variable is
    looked up. If that variable carries an ``error_clip`` attribute, the
    attribute's clip op is appended for the gradient.

    Args:
        block: The block whose last op is inspected.
        context (dict): Map from gradient name to forward variable name.

    Raises:
        TypeError: If a forward variable's ``error_clip`` is neither ``None``
            nor a ``BaseErrorClipAttr`` instance.
    """
    # the context is a grad_to_var map
    grad_to_var = context
    op_desc = block.desc.op(block.desc.op_size() - 1)
    for grad_n in op_desc.output_arg_names():
        if grad_n not in grad_to_var:
            continue
        fwd_var = block._var_recursive(grad_to_var[grad_n])
        error_clip = getattr(fwd_var, "error_clip", None)
        if error_clip is None:
            continue
        if not isinstance(error_clip, BaseErrorClipAttr):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
        error_clip._append_clip_op(block, grad_n)


class ClipGradBase(object):
    """Base class of the gradient clipping strategies.

    Calling an instance with a list of ``(param, grad)`` pairs dispatches to
    ``_dygraph_clip`` in dygraph mode and to ``_static_clip`` otherwise.
    ``_process_context`` and ``_create_operators`` are the hooks used by
    ``append_gradient_clip_ops`` for strategies attached through
    ``set_gradient_clip``.
    """

    def __init__(self):
        super(ClipGradBase, self).__init__()

    def __str__(self):
        raise NotImplementedError()

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        raise NotImplementedError

    def _static_clip(self, params_grads):
        raise NotImplementedError

    def __call__(self, params_grads):
        if framework.in_dygraph_mode():
            return self._dygraph_clip(params_grads)
        # This strategy (set as the optimizer's ``grad_clip``) takes precedence
        # over any strategy attached with ``set_gradient_clip``; warn once.
        if any(
                getattr(p, 'gradient_clip_attr', None) is not None
                for p, _ in params_grads):
            warnings.warn(
                "'set_gradient_clip' will be ineffective, because you have "
                "set 'grad_clip' in 'optimizer'. So, 'set_gradient_clip' "
                "is redundant and you can remove it.")
        return self._static_clip(params_grads)

    def _process_context(self, context, param, grad):
        raise NotImplementedError()

    def _create_operators(self, param, grad):
        raise NotImplementedError()


class ClipGradByValue(ClipGradBase):
    """
    Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].

    - Any values less than min are set to ``min``.

    - Any values greater than max are set to ``max``.

    The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    Note:
        ``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. If not set by user, it will be set to ``-max``
            automatically. In this case, ``max`` must be greater than 0.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByValue(min=-1, max=1)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(self, max, min=None):
        super(ClipGradByValue, self).__init__()
        if min is None:
            # The implied symmetric range [-max, max] is only valid for max > 0.
            assert max > 0.0, (
                "'max' must be greater than 0 when 'min' is not given")
            min = -max
        self.max = float(max)
        self.min = float(min)

    def __str__(self):
        return "Clip Gradient By Value, min = %f, max=%f" % (self.min, self.max)

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Clip each gradient element-wise into [min, max] (dygraph mode).

        Pairs whose grad is ``None`` are dropped from the result; params with
        ``need_clip`` set to ``False`` keep their original gradient.
        """
        params_and_grads = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            new_grad = layers.clip(x=g, min=self.min, max=self.max)
            params_and_grads.append((p, new_grad))
        return params_and_grads

    def _static_clip(self, params_grads):
        """Append ``clip`` ops for each gradient (static graph mode).

        Same filtering as ``_dygraph_clip``. The op_role_var of the new clip
        ops is then corrected to point at the clipped gradient names.
        """
        params_and_grads = []
        param_new_grad_name_dict = dict()
        with framework.name_scope('gradient_clip'):
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    new_grad = layers.clip(x=g, min=self.min, max=self.max)
                params_and_grads.append((p, new_grad))
                param_new_grad_name_dict[p.name] = new_grad.name
        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        # Value clipping is per-gradient, so no shared context is needed.
        pass

    def _create_operators(self, param, grad):
        new_grad = layers.clip(x=grad, min=self.min, max=self.max)
        return param, new_grad


class ClipGradByNorm(ClipGradBase):
    r"""
    Limit the l2 norm of multi-dimensional Tensor :math:`X` to ``clip_norm`` .

    - If the l2 norm of :math:`X` is greater than ``clip_norm`` , :math:`X` will be compressed by a ratio.

    - If the l2 norm of :math:`X` is less than or equal to ``clip_norm`` , nothing will be done.

    The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::
        Out =
        \left\{
            \begin{array}{ccl}
                X & & if (norm(X) \leq clip\_norm) \\
                \frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\
        \end{array}
        \right.


    where :math:`norm(X)` represents the L2 norm of :math:`X`.

    .. math::
        norm(X) = ( \sum_{i=1}^{n}|x\_i|^2)^{ \frac{1}{2}}

    Note:
        ``need_clip`` of ``ClipGradByNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        clip_norm(float): The maximum norm value.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(self, clip_norm):
        super(ClipGradByNorm, self).__init__()
        self.clip_norm = float(clip_norm)

    def __str__(self):
        return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Clip each gradient by its own l2 norm (dygraph mode).

        Pairs whose grad is ``None`` are dropped from the result; params with
        ``need_clip`` set to ``False`` keep their original gradient.
        """
        params_and_grads = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
            params_and_grads.append((p, new_grad))
        return params_and_grads

    def _static_clip(self, params_grads):
        """Append ``clip_by_norm`` ops for each gradient (static graph mode)."""
        params_and_grads = []
        with framework.name_scope('gradient_clip'):
            param_new_grad_name_dict = dict()
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
                param_new_grad_name_dict[p.name] = new_grad.name
                params_and_grads.append((p, new_grad))
        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        # Norm clipping is per-gradient, so no shared context is needed.
        pass

    def _create_operators(self, param, grad):
        new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
        return param, new_grad


class ClipGradByGlobalNorm(ClipGradBase):
    r"""
    Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
    :math:`t\_list` , and limit it to ``clip_norm`` .

    - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.

    - If the global norm is less than or equal to ``clip_norm`` , nothing will be done.

    The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::

        t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}

    where:

    .. math::

        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    Note:
        ``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Args:
        clip_norm (float): The maximum norm value.
        group_name (str, optional): The group name for this clip. Default value is ``default_group``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(self, clip_norm, group_name="default_group"):
        super(ClipGradByGlobalNorm, self).__init__()
        self.clip_norm = float(clip_norm)
        self.group_name = group_name

    def __str__(self):
        # NOTE(review): the value printed is the configured ``clip_norm``
        # despite the ``global_norm`` label; label kept for compatibility.
        return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Scale all clippable gradients by clip_norm / max(global_norm, clip_norm).

        If no gradient is clippable, ``params_grads`` is returned unchanged.
        Otherwise pairs whose grad is ``None`` are dropped from the result.
        """
        params_and_grads = []
        sum_square_list = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

            sum_square = _squared_l2_norm(merge_grad)
            sum_square_list.append(sum_square)

        # all parameters have been filtered out
        if len(sum_square_list) == 0:
            return params_grads

        global_norm_var = layers.concat(sum_square_list)
        global_norm_var = layers.reduce_sum(global_norm_var)
        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=layers.elementwise_max(
                x=global_norm_var, y=max_global_norm))
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            # TODO(wangxi): use inplace elementwise_mul
            new_grad = layers.elementwise_mul(x=g, y=clip_var)
            params_and_grads.append((p, new_grad))

        return params_and_grads

    def _static_clip(self, params_grads):
        """Static-graph global-norm clipping; gradients are scaled inplace.

        Squared norms are accumulated per dtype (fp16, fp32, other) and summed
        in float64 when any "other" (presumably fp64) term exists, else float32.
        """
        params_and_grads = []
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        # Last (param, grad) pair with a non-None grad. The global-norm ops are
        # attributed to it. Using the raw loop variables could hand
        # ``_optimized_guard`` a ``None`` grad when the final entry has no grad.
        last_p = last_g = None
        with framework.name_scope('gradient_clip'):
            for p, g in params_grads:
                if g is None:
                    continue
                last_p, last_g = p, g
                if getattr(p, 'need_clip', True) is False:
                    continue
                merge_grad = g
                with p.block.program._optimized_guard([p, g]):
                    if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                        merge_grad = layers.merge_selected_rows(g)
                        merge_grad = layers.get_tensor_from_selected_rows(
                            merge_grad)
                    sum_square = _squared_l2_norm(merge_grad)
                    if sum_square.dtype == core.VarDesc.VarType.FP16:
                        sum_square_list_fp16.append(sum_square)
                    elif sum_square.dtype == core.VarDesc.VarType.FP32:
                        sum_square_list_fp32.append(sum_square)
                    else:
                        sum_square_list.append(sum_square)

            # all parameters have been filtered out
            if len(sum_square_list) + len(sum_square_list_fp16) + len(
                    sum_square_list_fp32) == 0:
                return params_grads

            with last_p.block.program._optimized_guard([last_p, last_g]):
                sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"

                global_norm_var = []
                if len(sum_square_list_fp16) > 0:
                    global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
                    global_norm_var.append(
                        global_norm_var_fp16.astype(sum_dtype))
                if len(sum_square_list_fp32) > 0:
                    global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
                    if sum_dtype == 'float32':
                        global_norm_var.append(global_norm_var_fp32)
                    else:
                        global_norm_var.append(
                            global_norm_var_fp32.astype(sum_dtype))
                if len(sum_square_list) > 0:
                    # fp64
                    global_norm_var_other_dtype = layers.sums(sum_square_list)
                    global_norm_var.append(global_norm_var_other_dtype)

                global_norm_var = layers.sums(global_norm_var) if len(
                    global_norm_var) > 1 else global_norm_var[0]
                global_norm_var = layers.sqrt(x=global_norm_var)
                max_global_norm = layers.fill_constant(
                    shape=[1],
                    dtype=global_norm_var.dtype,
                    value=self.clip_norm)
                scale_var = layers.elementwise_div(
                    x=max_global_norm,
                    y=layers.elementwise_max(
                        x=max_global_norm, y=global_norm_var))
            param_new_grad_name_dict = dict()
            for p, g in params_grads:
                if g is None:
                    continue
                if getattr(p, 'need_clip', True) is False:
                    params_and_grads.append((p, g))
                    continue

                with p.block.program._optimized_guard([p, g]):
                    # inplace
                    scale_input = (scale_var.astype('float16')
                                   if g.dtype == core.VarDesc.VarType.FP16 else
                                   scale_var)
                    # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
                    # will be in different blocks with the gradient clip related ops.
                    # We need to handle the correct block, otherwise will encounter
                    # a 'NotFoundError' during compile time.
                    block = default_main_program().current_block()
                    block.append_op(
                        type='elementwise_mul',
                        inputs={'X': g,
                                'Y': scale_input},
                        outputs={'Out': g})

                param_new_grad_name_dict[p.name] = g.name
                params_and_grads.append((p, g))

        _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
        return params_and_grads

    def _process_context(self, context, param, grad):
        """Accumulate this gradient's squared l2 norm into the group's context.

        The first call for a group stores the clip value and a constant tensor
        holding it. Later calls check that the clip value matches.
        """
        if self.group_name not in context:
            context[self.group_name] = []
            context[self.group_name + "_clip_value"] = self.clip_norm
            context[self.group_name + "_clip"] = layers.fill_constant(
                shape=[1], dtype=grad.dtype, value=self.clip_norm)
        elif self.clip_norm != context[self.group_name + "_clip_value"]:
            raise ValueError(
                "All parameters' 'clip_norm' of a same group should be the same"
            )

        merge_grad = grad
        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
            merge_grad = layers.merge_selected_rows(grad)
            merge_grad = layers.get_tensor_from_selected_rows(merge_grad)

        local_norm_var = _squared_l2_norm(merge_grad)
        context[self.group_name].append(local_norm_var)

        self.context = context

    def _create_operators(self, param, grad):
        """Scale ``grad`` inplace by the group's scale factor.

        The group scale is computed once and cached in the context.
        """
        group_scale_name = self.group_name + "_scale"
        if group_scale_name not in self.context:
            group_norm_var = layers.sums(input=self.context[self.group_name])
            group_norm_var = layers.sqrt(x=group_norm_var)
            clip_var = self.context[self.group_name + "_clip"]
            group_scale_var = layers.elementwise_div(
                x=clip_var,
                y=layers.elementwise_max(
                    x=clip_var, y=group_norm_var))
            assert group_scale_var.shape == (1, )
            self.context[group_scale_name] = group_scale_var

        # inplace
        param.block.append_op(
            type='elementwise_mul',
            inputs={'X': grad,
                    'Y': self.context[group_scale_name]},
            outputs={'Out': grad})

        return param, grad


@framework.dygraph_not_support
def set_gradient_clip(clip, param_list=None, program=None):
    """
    :api_attr: Static Graph

    Warning:

        This API must be used after building network, and before ``minimize`` ,
        and it may be removed in future releases, so it is not recommended.
        It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
        this is a better method to clip gradient. There are three clipping strategies:
         :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
         :ref:`api_fluid_clip_GradientClipByValue` .

    To specify parameters that require gradient clip.

    Args:
        clip (ClipGradBase): Gradient clipping strategy, an instance of a class derived from
            ``ClipGradBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ).
        param_list (list(Variable|str), optional): Parameters that require gradient clip.
                It can be a list of parameters, a list of parameter names, or a mix of both.
                Default None, meaning that all parameters in the program will be included.
        program (Program, optional): The program where parameters are located.
                Default None, meaning that using :ref:`api_fluid_default_main_program` .

    Returns:
        None

    Raises:
        TypeError: If ``clip`` is not a ``ClipGradBase`` instance, or an element of
            ``param_list`` does not resolve to a ``Parameter``.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            def network():
                image = fluid.data(name='image', shape=[
                                   None, 28], dtype='float32')
                param_attr1 = fluid.ParamAttr("fc1_param")
                fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
                param_attr2 = fluid.ParamAttr("fc2_param")
                fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
                loss = fluid.layers.reduce_mean(fc2)
                return loss


            # network 1: clip all parameter gradient
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 2: clip parameter gradient by name
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=["fc1_param", "fc2_param"])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 3: clip parameter gradient by value
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                param_var1 = fluid.default_main_program().global_block().var("fc1_param")
                param_var2 = fluid.default_main_program().global_block().var("fc2_param")
                fluid.clip.set_gradient_clip(
                    fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
                    param_list=[param_var1, param_var2])
                sgd = fluid.optimizer.SGD(learning_rate=1e-3)
                sgd.minimize(loss)

            # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
            with fluid.program_guard(fluid.Program(), fluid.Program()):
                loss = network()
                clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0)
                clip2 = fluid.clip.GradientClipByNorm(clip_norm=1.0)
                # Set the gradient clipping strategy: clip1
                fluid.clip.set_gradient_clip(clip1)
                # Set the gradient clipping strategy: clip2
                sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
                sgd.minimize(loss)
                # 'set_gradient_clip' will not take effect when setting has a conflict,
                # and the gradient clipping strategy will be 'clip2'
    """
    warnings.warn("Caution! 'set_gradient_clip' is not recommended "
                  "and may be deprecated in future! "
                  "We recommend a new strategy: set 'grad_clip' "
                  "when initializing the 'optimizer'. "
                  "This method can reduce the mistakes, please "
                  "refer to documention of 'optimizer'.")

    if not isinstance(clip, ClipGradBase):
        raise TypeError(
            "'clip' should be an instance of ClipGradBase's derived class")
    if program is None:
        program = framework.default_main_program()

    # Ops in an "optimizer" name scope mean minimize() already ran.
    for op in program.block(0).ops:
        if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
                "op_namescope"):
            warnings.warn(
                "'minimize' has been invoked before, this will make 'set_gradient_clip' "
                "be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
            )
            break

    global_block = program.block(0)
    if param_list is None:
        param_list = global_block.all_parameters()
    # Resolve names element by element so a mix of names and Parameters is
    # accepted. Materializing a list also keeps a one-shot iterable (e.g. a
    # generator) from being exhausted before the type check below.
    param_list = [
        global_block.var(elem) if isinstance(elem, six.string_types) else elem
        for elem in param_list
    ]
    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    for param in param_list:
        # Each parameter gets its own deep copy of the strategy object.
        param.gradient_clip_attr = copy.deepcopy(clip)


def append_gradient_clip_ops(param_grads):
    """Append gradient-clipping ops for parameters configured by
    ``set_gradient_clip``.

    Each parameter with a non-None gradient is expected to carry a
    ``gradient_clip_attr`` (a ``ClipGradBase`` instance). Clipping is done in
    two passes:

    1. every clip object's ``_process_context`` is called with a dict shared
       across all parameters, so it can record information there;
    2. every clip object's ``_create_operators`` builds the clipped gradient,
       which may rely on what pass 1 recorded.

    Args:
        param_grads: sequence of ``(param, grad)`` pairs. Pairs whose grad is
            None are skipped.

    Returns:
        ``param_grads`` itself (unchanged) if any parameter with a gradient has
        no ``gradient_clip_attr``. Otherwise a list of ``[param, clipped_grad]``
        pairs (pairs with a None grad are not included).

    Raises:
        TypeError: if a ``gradient_clip_attr`` is not a ``ClipGradBase``.
    """
    # Resolve and validate every clip attribute before touching the program,
    # so that bailing out (or raising) never leaves half-built clip state or
    # ops behind for the parameters that happened to come first.
    clip_items = []
    for p, g in param_grads:
        if g is None:
            continue
        clip_attr = getattr(p, 'gradient_clip_attr', None)
        if clip_attr is None:
            return param_grads
        if not isinstance(clip_attr, ClipGradBase):
            raise TypeError(
                "clip attribute should be an instance of GradientClipBase")
        clip_items.append((p, g, clip_attr))

    # Pass 1: let each clip object record shared information into `context`
    # before any clipped gradient is created.
    context = dict()
    for p, g, clip_attr in clip_items:
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('gradient_clip'):
            clip_attr._process_context(context=context, param=p, grad=g)

    # Pass 2: build the clipping ops with each parameter's *own* clip object.
    # Parameters can carry different clip settings (e.g. set_gradient_clip
    # called several times on different parameter lists), so reusing a single
    # clip object for every parameter would apply the wrong settings.
    res = []
    param_new_grad_name_dict = dict()
    for p, g, clip_attr in clip_items:
        with p.block.program._optimized_guard(
            [p, g]), framework.name_scope('gradient_clip'):
            param, new_grad = clip_attr._create_operators(param=p, grad=g)
            param_new_grad_name_dict[param.name] = new_grad.name
            res.append([param, new_grad])

    _correct_clip_op_role_var(res, param_new_grad_name_dict)
    return res


# change wrong mapping relation between param & grad in clip op
770 771
# Note: This function is sensitive to the time cost of the network with gradient clipping 
# and should not be changed easily. If you must change, please test the time cost.
772 773 774 775
def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
    block_id_list = []
    if len(param_new_grad_name_dict) == 0:
        return
776 777
    for param, grad in params_grads:
        if grad is None:
778
            continue
779 780 781 782
        block_id = param.block.idx
        if block_id in block_id_list:
            continue
        block_id_list.append(block_id)
783
        for op in param.block.program.global_block().ops:
W
WangXi 已提交
784
            if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
785 786 787 788 789 790
                    "op_namescope") and op.attr('op_role_var'):
                param_name = op.attr('op_role_var')[0]
                if param_name in param_new_grad_name_dict:
                    correct_p_g = [
                        param_name, param_new_grad_name_dict[param_name]
                    ]
C
Chengmo 已提交
791
                    op._set_attr('op_role_var', correct_p_g)
Y
Yu Yang 已提交
792 793


794 795 796 797
# Bind the GradientClip* names to the very same class objects as the
# ClipGrad* classes (plain aliases, no subclassing), so either spelling refers
# to one class. Note these aliases are not listed in this module's __all__.
(GradientClipBase, GradientClipByValue, GradientClipByNorm,
 GradientClipByGlobalNorm) = (ClipGradBase, ClipGradByValue, ClipGradByNorm,
                              ClipGradByGlobalNorm)