#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.distributed as dist
from paddle.autograd import no_grad
from paddle.framework import core
from paddle.nn import clip
from paddle.nn.clip import ClipGradBase, _squared_l2_norm


class ClipGradForMOEByGlobalNorm(ClipGradBase):
    r"""
    The algorithm is the same as paddle.nn.ClipGradByGlobalNorm.
    Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
    :math:`t\_list` , and limit it to ``clip_norm`` .

    - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.

    - If the global norm is less than or equal to ``clip_norm`` , nothing will be done.

    The list of Tensor :math:`t\_list` is not passed to this class directly; it consists of the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of a specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clipping takes effect after being set in ``optimizer`` , see the document of ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:

    .. math::

        t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}

    where:

    .. math::

        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    Note:
        ``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Reference:
        https://github.com/laekov/fastmoe/blob/master/examples/megatron/clip-grad-v2.2.patch
        Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4


    Args:
        clip_norm (float): The maximum norm value.
        is_expert_param_func (function): A function that decides whether a param should be put into moe_params_grads.
        moe_group (Group): The communication group of the moe experts.
        group_name (str, optional): The group name for this clip. Default value is ``default_moe_group``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            is_expert_func = lambda param: "expert_" in param.name
            clip = ClipGradForMOEByGlobalNorm(clip_norm=1.0,
                                              is_expert_param_func=is_expert_func,
                                              moe_group=None)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(
        self,
        clip_norm,
        is_expert_param_func=None,
        moe_group=None,
        group_name="default_moe_group",
    ):
        super().__init__()
        self.clip_norm = float(clip_norm)
        self.group_name = group_name
        self.moe_group = moe_group
        if moe_group is not None and moe_group.nranks > 1:
            assert (
                is_expert_param_func is not None
            ), "When moe group size > 1, a function for selecting expert params must be specified."
        self.is_expert_param_func = is_expert_param_func

    def __str__(self):
        return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

    @staticmethod
    def get_l2_norm_pow(params_grads, sum_dtype=None):
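        # Compute the squared global norm (sum of squared L2 norms) of all
        # clippable gradients in params_grads. Partial sums are bucketed by
        # dtype (fp16 / fp32 / other) and reduced separately before being
        # combined; the dtype used for the final accumulation is returned
        # alongside the result so a second call can reuse it.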
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
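                # Sparse gradients (SELECTED_ROWS) are merged into a dense
                # tensor before the squared L2 norm is computed.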
                merge_grad = clip.merge_selected_rows(g)
                merge_grad = clip.get_tensor_from_selected_rows(merge_grad)
            sum_square = _squared_l2_norm(merge_grad)
            if sum_square.dtype == core.VarDesc.VarType.FP16:
                sum_square_list_fp16.append(sum_square)
            elif sum_square.dtype == core.VarDesc.VarType.FP32:
                sum_square_list_fp32.append(sum_square)
            else:
                sum_square_list.append(sum_square)

        # all parameters have been filtered out
        if (
            len(sum_square_list)
            + len(sum_square_list_fp16)
            + len(sum_square_list_fp32)
            == 0
        ):
            return None, None
        assert sum_dtype in [
            "float64",
            "float32",
            None,
        ], "sum's type must be float64/ float32 / None"
        if sum_dtype != "float64":
            sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"

        global_norm_var = []
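        # Reduce each dtype bucket to a scalar, cast to the accumulation dtype
        # where needed, then sum the per-bucket results into one squared norm.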
        if len(sum_square_list_fp16) > 0:
            global_norm_var_fp16 = paddle.concat(sum_square_list_fp16)
            global_norm_var_fp16 = paddle.sum(global_norm_var_fp16)
            global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
        if len(sum_square_list_fp32) > 0:
            global_norm_var_fp32 = paddle.concat(sum_square_list_fp32)
            global_norm_var_fp32 = paddle.sum(global_norm_var_fp32)
            if sum_dtype == 'float32':
                global_norm_var.append(global_norm_var_fp32)
            else:
                global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
        if len(sum_square_list) > 0:
            global_norm_var_fp64 = paddle.concat(sum_square_list)
            global_norm_var_fp64 = paddle.sum(global_norm_var_fp64)
            global_norm_var.append(global_norm_var_fp64)
        global_norm_var = paddle.concat(global_norm_var)
        global_norm_var = paddle.sum(global_norm_var)
        return global_norm_var, sum_dtype

    @no_grad()
    def _dygraph_clip(self, params_grads):
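        # Clip gradients by a global norm computed over both normal and
        # expert (MoE) parameters; the expert contribution is summed across
        # the ranks of moe_group before the two parts are combined.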
        normal_params_grads = []
        moe_params_grads = []

        # separate moe params from normal params
        if self.moe_group is not None and self.moe_group.nranks > 1:
            for p, g in params_grads:
                if self.is_expert_param_func(p):
                    moe_params_grads.append((p, g))
                else:
                    normal_params_grads.append((p, g))
        else:
            normal_params_grads = params_grads

        # Why return sum_dtype?
        # We will call `get_l2_norm_pow` twice and the precisions may differ.
        # For convenience and simplicity, we use sum_dtype directly instead of
        # global_norm_var_normal.dtype.
        global_norm_var_normal, sum_dtype = self.get_l2_norm_pow(
            normal_params_grads
        )
        global_norm_var_moe = None
        if len(moe_params_grads) > 0:
            global_norm_var_moe, _ = self.get_l2_norm_pow(
                moe_params_grads, sum_dtype
            )
            if global_norm_var_moe is not None:
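                # Each rank only sees its local experts, so sum the squared
                # expert norm across moe_group to obtain the full value.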
                dist.all_reduce(
                    global_norm_var_moe,
                    op=dist.ReduceOp.SUM,
                    group=self.moe_group,
                )

        if global_norm_var_normal is None and global_norm_var_moe is None:
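            # Neither part produced a norm: every gradient was filtered out,
            # so there is nothing to clip.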
            return params_grads
        elif global_norm_var_normal is None:
            global_norm_var = global_norm_var_moe
        elif global_norm_var_moe is None:
            global_norm_var = global_norm_var_normal
        else:
            if global_norm_var_normal.dtype != global_norm_var_moe.dtype:
                # Compared with the normal norm, the moe norm is computed
                # later, so its precision is no lower than the normal norm's.
                global_norm_var_normal = global_norm_var_normal.astype(
                    global_norm_var_moe.dtype
                )
            global_norm_var = global_norm_var_normal + global_norm_var_moe

        params_and_grads = []
        global_norm_var = paddle.sqrt(global_norm_var)
        max_global_norm = paddle.full(
            shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm
        )
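        # Scale factor clip_norm / max(global_norm, clip_norm): equal to 1.0
        # when the global norm is within the threshold, smaller otherwise.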
        clip_var = paddle.divide(
            x=max_global_norm,
            y=paddle.maximum(x=global_norm_var, y=max_global_norm),
        )
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            # TODO(wangxi): use inplace elementwise_mul
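            # Cast the scale factor to fp16 for fp16 gradients so the multiply
            # below keeps the gradient dtype.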
            clip_input = (
                clip_var.astype('float16')
                if g.dtype == core.VarDesc.VarType.FP16
                else clip_var
            )
            new_grad = paddle.multiply(x=g, y=clip_input)
            params_and_grads.append((p, new_grad))
        return params_and_grads


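# Within this module, ClipGradByGlobalNorm is an alias for the MoE-aware
# implementation above, so importing it from here yields expert-aware clipping.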
ClipGradByGlobalNorm = ClipGradForMOEByGlobalNorm