grad_clip.py 9.4 KB
Newer Older
R
Roc 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
HongyuJia 已提交
15
import paddle
16
import paddle.distributed as dist
17
from paddle.fluid import core, layers
R
Roc 已提交
18 19 20 21 22 23 24
from paddle.fluid.clip import ClipGradBase, _squared_l2_norm
from paddle.fluid.dygraph import base as imperative_base


class ClipGradForMOEByGlobalNorm(ClipGradBase):
    r"""
    The algorithm is the same as paddle.fluid.clip.ClipGradByGlobalNorm.

    Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
    :math:`t\_list` , and limit it to ``clip_norm`` .

    - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.

    - If the global norm is less than or equal to ``clip_norm`` , nothing will be done.

    The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    When ``moe_group`` has more than one rank, the gradients selected by ``is_expert_param_func``
    (the MoE expert gradients) have their squared L2 norm summed across ``moe_group`` with
    ``all_reduce`` before it is added to the squared L2 norm of the remaining gradients.

    The clipping formula is:

    .. math::

        t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}

    where:

    .. math::

        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    Note:
        ``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.

    Reference:
        https://github.com/laekov/fastmoe/blob/master/examples/megatron/clip-grad-v2.2.patch
        Git commit hash: 295a615aacce7e54a37e7935274ba15e901c78e4


    Args:
        clip_norm (float): The maximum norm value.
        is_expert_param_func (function, optional): A function that takes a parameter and returns
            ``True`` if its gradient should be put into ``moe_params_grads``. It must be given
            when ``moe_group`` has more than one rank. Default: None.
        moe_group (Group, optional): Group for MoE experts communication. Default: None.
        group_name (str, optional): The group name for this clip. Default value is ``default_moe_group``.

    Examples:
        .. code-block:: python

            import paddle
            # ClipGradForMOEByGlobalNorm must be imported from the module that defines it.

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()

            is_expert_func = lambda param: "expert_" in param.name
            clip = ClipGradForMOEByGlobalNorm(clip_norm=1.0,
                                              is_expert_param_func=is_expert_func,
                                              moe_group=None)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
            sgd.step()
    """

    def __init__(
        self,
        clip_norm,
        is_expert_param_func=None,
        moe_group=None,
        group_name="default_moe_group",
    ):
        super().__init__()
        self.clip_norm = float(clip_norm)
        self.group_name = group_name
        self.moe_group = moe_group
        if moe_group is not None and moe_group.nranks > 1:
            # Expert gradients can only be split from the normal ones (see
            # `_dygraph_clip`) if we know which parameters are experts.
            assert (
                is_expert_param_func is not None
            ), "When moe group size > 1, a function for selecting expert params must be specified."
        self.is_expert_param_func = is_expert_param_func

    def __str__(self):
        # The value formatted here is `clip_norm` (the clipping threshold).
        return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)

    @staticmethod
    def get_l2_norm_pow(params_grads, sum_dtype=None):
        """Sum the squared L2 norms of the clippable gradients in ``params_grads``.

        Gradients that are ``None``, or whose parameter has ``need_clip`` set
        to ``False``, are skipped. ``SELECTED_ROWS`` gradients are merged and
        converted to a dense tensor before their norm is taken. Per-gradient
        squared norms are bucketed by dtype (FP16, FP32, anything else), each
        bucket is summed, the bucket sums are cast to a common dtype, and the
        results are summed again.

        Args:
            params_grads (list): ``(param, grad)`` pairs.
            sum_dtype (str|None): ``"float64"``, ``"float32"`` or ``None``.
                Unless it is ``"float64"``, the summation dtype becomes
                ``"float64"`` when any squared norm falls in the "anything
                else" bucket, and ``"float32"`` otherwise.

        Returns:
            tuple: ``(squared_global_norm, sum_dtype)``, or ``(None, None)``
            when every gradient was filtered out. The first element is the
            *squared* norm; ``_dygraph_clip`` takes its square root.
        """
        # Squared norms that are neither FP16 nor FP32. Presumably FP64;
        # NOTE(review): a BF16 squared norm would also land here uncast -- confirm.
        sum_square_list = []
        sum_square_list_fp16 = []
        sum_square_list_fp32 = []
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                continue
            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                # Sparse gradient: merge it and extract a dense tensor so the
                # norm can be computed on it.
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
            sum_square = _squared_l2_norm(merge_grad)
            if sum_square.dtype == core.VarDesc.VarType.FP16:
                sum_square_list_fp16.append(sum_square)
            elif sum_square.dtype == core.VarDesc.VarType.FP32:
                sum_square_list_fp32.append(sum_square)
            else:
                sum_square_list.append(sum_square)

        # all parameters have been filtered out
        if not (sum_square_list or sum_square_list_fp16 or sum_square_list_fp32):
            return None, None
        assert sum_dtype in [
            "float64",
            "float32",
            None,
        ], "sum's type must be float64/ float32 / None"
        if sum_dtype != "float64":
            sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"

        global_norm_var = []
        if len(sum_square_list_fp16) > 0:
            global_norm_var_fp16 = layers.concat(sum_square_list_fp16)
            global_norm_var_fp16 = paddle.sum(global_norm_var_fp16)
            global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
        if len(sum_square_list_fp32) > 0:
            global_norm_var_fp32 = layers.concat(sum_square_list_fp32)
            global_norm_var_fp32 = paddle.sum(global_norm_var_fp32)
            if sum_dtype == 'float32':
                global_norm_var.append(global_norm_var_fp32)
            else:
                global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
        if len(sum_square_list) > 0:
            global_norm_var_fp64 = layers.concat(sum_square_list)
            global_norm_var_fp64 = paddle.sum(global_norm_var_fp64)
            global_norm_var.append(global_norm_var_fp64)
        global_norm_var = layers.concat(global_norm_var)
        global_norm_var = paddle.sum(global_norm_var)
        return global_norm_var, sum_dtype

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """Scale gradients so that their global L2 norm is at most ``clip_norm``.

        When ``moe_group`` has more than one rank, the expert gradients
        (selected by ``is_expert_param_func``) have their squared norm summed
        across the group before it is combined with the squared norm of the
        other gradients.

        Returns:
            list: ``params_grads`` unchanged if no gradient contributes to the
            norm. Otherwise a new list of ``(param, grad)`` pairs, in which
            pairs whose grad is ``None`` are dropped, pairs whose param has
            ``need_clip`` set to ``False`` are passed through unchanged, and
            every other grad is multiplied by
            ``clip_norm / max(global_norm, clip_norm)``.
        """
        normal_params_grads = []
        moe_params_grads = []

        # separate moe params from normal params
        if self.moe_group is not None and self.moe_group.nranks > 1:
            for p, g in params_grads:
                if self.is_expert_param_func(p):
                    moe_params_grads.append((p, g))
                else:
                    normal_params_grads.append((p, g))
        else:
            normal_params_grads = params_grads

        # Why return sum_dtype?
        # `get_l2_norm_pow` is called twice and the two precisions may differ.
        # For simplicity we pass sum_dtype along directly instead of reading
        # global_norm_var_normal.dtype.
        global_norm_var_normal, sum_dtype = self.get_l2_norm_pow(
            normal_params_grads
        )
        global_norm_var_moe = None
        if len(moe_params_grads) > 0:
            global_norm_var_moe, _ = self.get_l2_norm_pow(
                moe_params_grads, sum_dtype
            )
            if global_norm_var_moe is not None:
                # Sum the expert squared norms over every rank of the MoE group.
                # NOTE(review): all_reduce is a collective, but it is skipped on
                # a rank with no expert grads (or whose expert grads are all
                # None / need_clip=False), which would leave the other ranks
                # waiting. Confirm every rank always contributes here.
                dist.all_reduce(
                    global_norm_var_moe,
                    op=dist.ReduceOp.SUM,
                    group=self.moe_group,
                )

        if global_norm_var_normal is None and global_norm_var_moe is None:
            return params_grads
        elif global_norm_var_normal is None:
            global_norm_var = global_norm_var_moe
        elif global_norm_var_moe is None:
            global_norm_var = global_norm_var_normal
        else:
            if global_norm_var_normal.dtype != global_norm_var_moe.dtype:
                # compared with normal norm, moe norm is the later one,
                # so its precision is no lower than normal norm
                global_norm_var_normal = global_norm_var_normal.astype(
                    global_norm_var_moe.dtype
                )
            global_norm_var = global_norm_var_normal + global_norm_var_moe

        params_and_grads = []
        global_norm_var = paddle.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
        )
        # Scale factor: clip_norm / max(global_norm, clip_norm), which is <= 1.
        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=paddle.maximum(x=global_norm_var, y=max_global_norm),
        )
        for p, g in params_grads:
            if g is None:
                continue
            if getattr(p, 'need_clip', True) is False:
                params_and_grads.append((p, g))
                continue
            # TODO(wangxi): use inplace elementwise_mul
            # Cast the scale to the gradient's dtype whenever they differ (e.g.
            # fp16/bf16 grads with an fp32 scale, or fp32 grads with an fp64
            # scale), so the multiply below sees matching dtypes.
            clip_input = (
                clip_var.astype(g.dtype) if clip_var.dtype != g.dtype else clip_var
            )
            new_grad = layers.elementwise_mul(x=g, y=clip_input)
            params_and_grads.append((p, new_grad))
        return params_and_grads


# Module-level alias: within this module, the name ``ClipGradByGlobalNorm``
# refers to the MoE-aware ClipGradForMOEByGlobalNorm class defined above.
ClipGradByGlobalNorm = ClipGradForMOEByGlobalNorm