#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from enum import Enum
from types import MethodType

import numpy as np

import paddle
from paddle import _legacy_C_ops
from paddle.fluid import core, layers
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only


class Taskflow:
    """
    Task flow: a one-way linked list used for task acquisition.
    """

    def __init__(self, task, callback):
        self.task = task
        self.callback = callback


class Type(Enum):
    """
    Type of trainable parameters
    """

    fp16 = paddle.float16
    bf16 = paddle.bfloat16
    fp32 = paddle.float32


class ShardingClipGrad:
    def __init__(self, clip, device, group):
        self._clip = clip
        self._device = device
        self._group = group

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        sum_square_fp32, sum_square_fp16 = [], []
        unslice_params_fp32, unslice_params_fp16 = [], []
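        # note: despite the name, the unslice_params_* lists collect squared-norm
        # sums for unsliced (replicated) parameters, mirroring sum_square_* for
        # sliced ones.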

        for p, g in params_grads:
            p_slice = True  # marks parameters that are sliced in sharding stage3
            if g is None or getattr(p, 'need_clip', True) is False:
                continue
            if hasattr(p, "unslice"):
                p_slice = False

            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.get_tensor_from_selected_rows(
                    layers.merge_selected_rows(g)
                )
            square = paddle.square(merge_grad)
            sum_square = paddle.sum(square)
            if p.dtype == paddle.float16:
                if p_slice:
                    sum_square_fp16.append(sum_square)
                else:
                    unslice_params_fp16.append(sum_square)
            elif p.dtype == paddle.float32:
                if p_slice:
                    sum_square_fp32.append(sum_square)
                else:
                    unslice_params_fp32.append(sum_square)

        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_fp16) == 0:
            global_norm_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
        else:
            global_norm_fp16 = layers.concat(sum_square_fp16)
            global_norm_fp16 = paddle.sum(global_norm_fp16)
            global_norm_fp16 = paddle.cast(
                global_norm_fp16, dtype=paddle.float32
            )

        # global norm of non-distributed FP16 params_and_grads for unsliced parameters
        if len(unslice_params_fp16) == 0:
            global_unslice_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
        else:
            global_unslice_fp16 = layers.concat(unslice_params_fp16)
            global_unslice_fp16 = paddle.sum(global_unslice_fp16)
            global_unslice_fp16 = paddle.cast(
                global_unslice_fp16, dtype=paddle.float32
            )

        # global norm of non-distributed FP32 params_and_grads
        global_norm_fp32 = (
            layers.concat(sum_square_fp32)
            if len(sum_square_fp32) != 0
            else paddle.to_tensor([0.0], dtype=paddle.float32)
        )
        global_norm_fp32 = paddle.sum(global_norm_fp32)

        # global norm of non-distributed FP32 params_and_grads for unsliced parameters
        global_unslice_fp32 = (
            layers.concat(unslice_params_fp32)
            if len(unslice_params_fp32) != 0
            else paddle.to_tensor([0.0], dtype=paddle.float32)
        )
        global_unslice_fp32 = paddle.sum(global_unslice_fp32)
        global_unslice_var = global_unslice_fp16 + global_unslice_fp32

        global_norm_var = (
            global_norm_fp16
            + global_norm_fp32
            + 1.0 / self._group.nranks * global_unslice_var
        )
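        # unsliced parameters are replicated on every rank, so their contribution
        # is pre-scaled by 1/nranks; the all_reduce below sums the same value from
        # every rank, restoring it exactly once.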

        # add all reduce to get global norm of distributed params_and_grads
        dev_id = int(self._device.split(":")[1])
        with device_guard(dev_id, "gpu"):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        global_norm_var = paddle.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
        )

        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=paddle.maximum(x=global_norm_var, y=max_global_norm),
        )
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16)
            else:
                g.scale_(clip_var)
            g.stop_gradient = origin_state
            p._reset_grad_inplace_version(True)

        return params_grads

    def __getattr__(self, item):
        return getattr(self._clip, item)

    def __call__(self, params_grads):
        return self._dygraph_clip(params_grads)
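
# A usage sketch (assumed, not part of this file): ShardingClipGrad stands in for
# an optimizer's ClipGradByGlobalNorm so that the squared norms are all-reduced
# across the sharding group before clipping. The group construction below is
# illustrative only.
#
#     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
#     group = paddle.distributed.new_group(ranks=[0, 1])
#     sharded_clip = ShardingClipGrad(clip, paddle.get_device(), group)
#     params_grads = sharded_clip(params_grads)  # grads are scaled in place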


@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)
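
# Example (a sketch; the "gpu" branch assumes a CUDA device is available):
#
#     with device_guard(dev_id=0, device="gpu"):
#         ...  # ops here run with "gpu:0" selected
#     # the previously selected device is restored on exit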


@dygraph_only
def ShardingScaler(scaler):
    def unscale_method(self, optimizer):
        if not self._enable:
            return
        param_grads = []
        param_grads_fp16 = []
        param_grads_fp32 = []
        if hasattr(optimizer, "update_slice"):
            optimizer.update_slice()
            optimizer.update_scaler = True

        if getattr(optimizer._optim, '_param_groups', None) and isinstance(
            optimizer._optim._param_groups[0], dict
        ):

            for group in optimizer._optim._param_groups:
                for param in group['params']:
                    if param._grad_ivar() is not None:
                        param_grads.append(param._grad_ivar())
                        if param._grad_ivar().dtype in [
                            core.VarDesc.VarType.FP16,
                            paddle.float16,
                        ]:
                            param_grads_fp16.append(param._grad_ivar())
                        else:
                            param_grads_fp32.append(param._grad_ivar())
        else:
            for param in optimizer._optim._parameter_list:
                if param.grad is not None:
                    param_grads.append(param.grad)
                    if param.grad.dtype in [
                        core.VarDesc.VarType.FP16,
                        paddle.float16,
                    ]:
                        param_grads_fp16.append(param.grad)
                    else:
                        param_grads_fp32.append(param.grad)

        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))

        device = "cpu" if optimizer.offload else "gpu"
        dev_id = (
            0 if device == "cpu" else int(paddle.get_device().split(":")[1])
        )

        with device_guard(dev_id, device):
            if len(param_grads_fp16):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_fp16,
                    self._scale,
                    param_grads_fp16,
                    temp_found_inf_fp16,
                )
            if len(param_grads_fp32):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_fp32,
                    self._scale,
                    param_grads_fp32,
                    temp_found_inf_fp32,
                )

        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

        paddle.distributed.all_reduce(
            is_found_inf,
            op=paddle.distributed.ReduceOp.MAX,
            group=optimizer.group,
        )
        self._found_inf = is_found_inf.numpy()[0]

    scaler._unscale = MethodType(unscale_method, scaler)
    return scaler
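
# Usage sketch (assumed): ShardingScaler patches a paddle.amp.GradScaler so that
# unscaling checks fp16/fp32 grads for inf/nan locally and then all-reduces the
# found_inf flag across the sharding group. `sharding_optimizer` stands for the
# sharding optimizer wrapper that exposes `_optim`, `offload` and `group`.
#
#     scaler = paddle.amp.GradScaler(init_loss_scaling=2.0**15)
#     scaler = ShardingScaler(scaler)
#     loss = model(batch)
#     scaler.scale(loss).backward()
#     scaler.step(sharding_optimizer)
#     scaler.update()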