#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from enum import Enum
from types import MethodType

import numpy as np

import paddle
from paddle import _legacy_C_ops
from paddle.fluid import core, layers
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only


class Taskflow:
    """
    Task flows, one way linked list for task acquisition.
    """

    def __init__(self, task, callback):
        self.task = task
        self.callback = callback


class Type(Enum):
    """
    Type of trainable parameters
    """

    fp16 = paddle.float16
    bf16 = paddle.bfloat16
    fp32 = paddle.float32


class ShardingClipGrad:
    """
    Wrap a gradient clipping strategy so that the global norm is accumulated
    across all ranks of the sharding group before gradients are scaled.
    """

    def __init__(self, clip, device, group):
        self._clip = clip
        self._device = device
        self._group = group

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        sum_square_fp32, sum_square_fp16 = [], []
        unslice_params_fp32, unslice_params_fp16 = [], []

        for p, g in params_grads:
            p_slice = True  # whether the parameter is sliced in sharding stage3
            if g is None or getattr(p, 'need_clip', True) is False:
                continue
            if hasattr(p, "unslice"):
                p_slice = False

            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.get_tensor_from_selected_rows(
                    layers.merge_selected_rows(g)
                )
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)

            if p.dtype == paddle.float16:
                if p_slice:
                    sum_square_fp16.append(sum_square)
                else:
                    unslice_params_fp16.append(sum_square)
            elif p.dtype == paddle.float32:
                if p_slice:
                    sum_square_fp32.append(sum_square)
                else:
                    unslice_params_fp32.append(sum_square)

        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_fp16) == 0:
            global_norm_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
        else:
            global_norm_fp16 = layers.concat(sum_square_fp16)
            global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
            global_norm_fp16 = paddle.cast(
                global_norm_fp16, dtype=paddle.float32
            )

        # global norm of non-distributed FP16 params_and_grads for unsliced parameters
        if len(unslice_params_fp16) == 0:
            global_unslice_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
        else:
            global_unslice_fp16 = layers.concat(unslice_params_fp16)
            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
            global_unslice_fp16 = paddle.cast(
                global_unslice_fp16, dtype=paddle.float32
            )

        # global norm of non-distributed FP32 params_and_grads
        global_norm_fp32 = (
            layers.concat(sum_square_fp32)
            if len(sum_square_fp32) != 0
            else paddle.to_tensor([0.0], dtype=paddle.float32)
        )
        global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

        # global norm of non-distributed FP32 params_and_grads for unsliced parameters
        global_unslice_fp32 = (
            layers.concat(unslice_params_fp32)
            if len(unslice_params_fp32) != 0
            else paddle.to_tensor([0.0], dtype=paddle.float32)
        )
        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
        global_unslice_var = global_unslice_fp16 + global_unslice_fp32

        # unsliced parameters are replicated on every rank, so their norm is
        # scaled by 1/nranks before the group-wide all_reduce sums it up
        global_norm_var = (
            global_norm_fp16
            + global_norm_fp32
            + 1.0 / self._group.nranks * global_unslice_var
        )

        # add all_reduce to get the global norm of distributed params_and_grads
        dev_id = int(self._device.split(":")[1])
        with device_guard(dev_id, "gpu"):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
        )

        # clip_var = clip_norm / max(global_norm, clip_norm): 1.0 when the norm
        # is already within the threshold, < 1.0 otherwise
        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=layers.elementwise_max(x=global_norm_var, y=max_global_norm),
        )
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16)
            else:
                g.scale_(clip_var)
            g.stop_gradient = origin_state
            p._reset_grad_inplace_version(True)

        return params_grads

    def __getattr__(self, item):
        return getattr(self._clip, item)

    def __call__(self, params_grads):
        return self._dygraph_clip(params_grads)
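
# A minimal usage sketch in comments (assumptions: an initialized process group
# and an optimizer wired to use this object as its grad clip; the wiring below
# is illustrative, not the required API):
#
#     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
#     group = paddle.distributed.new_group(ranks=list(range(nranks)))
#     sharding_clip = ShardingClipGrad(clip, paddle.get_device(), group)
#     params_grads = sharding_clip(params_grads)  # clips using the group-wide norm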


@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    """Temporarily switch the active device and restore it on exit."""
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)
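
# Example in comments (a sketch, not executed at import time): run an op on the
# CPU, e.g. for offloaded state, then fall back to the previously active device.
#
#     with device_guard(dev_id=0, device="cpu"):
#         offloaded = paddle.zeros([8], dtype="float32")  # created on the CPU
#     # the device that was active before entering the guard is restored here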


@dygraph_only
def ShardingScaler(scaler):
    """
    Patch ``scaler`` so that its unscale step checks gradients for inf/nan on
    this rank and then synchronizes the result across the sharding group.
    """

    def unscale_method(self, optimizer):
        if not self._enable:
            return
        param_grads = []
        param_grads_fp16 = []
        param_grads_fp32 = []
        if hasattr(optimizer, "update_slice"):
            optimizer.update_slice()
            optimizer.update_scaler = True

        if getattr(optimizer._optim, '_param_groups', None) and isinstance(
            optimizer._optim._param_groups[0], dict
        ):
            for group in optimizer._optim._param_groups:
                for param in group['params']:
                    if param._grad_ivar() is not None:
                        param_grads.append(param._grad_ivar())
                        if param._grad_ivar().dtype in [
                            core.VarDesc.VarType.FP16,
                            paddle.float16,
                        ]:
                            param_grads_fp16.append(param._grad_ivar())
                        else:
                            param_grads_fp32.append(param._grad_ivar())
        else:
            for param in optimizer._optim._parameter_list:
                if param.grad is not None:
                    param_grads.append(param.grad)
                    if param.grad.dtype in [
                        core.VarDesc.VarType.FP16,
                        paddle.float16,
                    ]:
                        param_grads_fp16.append(param.grad)
                    else:
                        param_grads_fp32.append(param.grad)

        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))

        device = "cpu" if optimizer.offload else "gpu"
        dev_id = (
            0 if device == "cpu" else int(paddle.get_device().split(":")[1])
        )

        with device_guard(dev_id, device):
            if len(param_grads_fp16):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_fp16,
                    self._scale,
                    param_grads_fp16,
                    temp_found_inf_fp16,
                )
            if len(param_grads_fp32):
                _legacy_C_ops.check_finite_and_unscale(
                    param_grads_fp32,
                    self._scale,
                    param_grads_fp32,
                    temp_found_inf_fp32,
                )

        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

        paddle.distributed.all_reduce(
            is_found_inf,
            op=paddle.distributed.ReduceOp.MAX,
            group=optimizer.group,
        )
        self._found_inf = is_found_inf.numpy()[0]

    scaler._unscale = MethodType(unscale_method, scaler)
    return scaler
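

# A hedged usage sketch in comments (assumed setup: `optimizer` is a sharding
# optimizer exposing `offload` and `group`, as relied on above):
#
#     scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
#     scaler = ShardingScaler(scaler)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)  # runs the patched _unscale and the group all_reduce
#     scaler.update()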