#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from enum import Enum
import numpy as np
from types import MethodType

import paddle
from paddle import _legacy_C_ops
from paddle.fluid import core
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph import base as imperative_base


class Taskflow:
    """
    Task flows, one way linked list for task acquisition.
    """

    def __init__(self, task, callback):
        self.task = task
        self.callback = callback


class Type(Enum):
    """
    Type of trainable parameters
    """
    fp16 = paddle.float16
    fp32 = paddle.float32


class ShardingClipGrad:
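    """
    Wrap a gradient clipping strategy (e.g. ClipGradByGlobalNorm) for sharded
    training: the global norm is built from this rank's sliced parameters plus
    the replicated (unsliced) ones, all-reduced over ``group``, and the
    resulting scale is applied to the gradients in place. Attributes not
    defined here are delegated to the wrapped ``clip`` object.
    """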

    def __init__(self, clip, device, group):
        self._clip = clip
        self._device = device
        self._group = group

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
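        # accumulate per-parameter squared gradient norms, separated by dtype
        # (FP16/FP32) and by sliced vs. unsliced (replicated) parameters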
        sum_square_fp32, sum_square_fp16 = [], []
        unslice_params_fp32, unslice_params_fp16 = [], []

        for p, g in params_grads:
            p_slice = True  # whether this parameter is sliced in sharding stage3
            if g is None or getattr(p, 'need_clip', True) is False:
                continue
            if hasattr(p, "unslice"):
                p_slice = False

            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.get_tensor_from_selected_rows(
                    layers.merge_selected_rows(g))
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)

            if p.dtype == paddle.float16:
                if p_slice:
                    sum_square_fp16.append(sum_square)
                else:
                    unslice_params_fp16.append(sum_square)
            elif p.dtype == paddle.float32:
                if p_slice:
                    sum_square_fp32.append(sum_square)
                else:
                    unslice_params_fp32.append(sum_square)

        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_fp16) == 0:
            global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_norm_fp16 = layers.concat(sum_square_fp16)
            global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
            global_norm_fp16 = paddle.cast(global_norm_fp16,
                                           dtype=paddle.float32)

        # global norm of non-distributed FP16 params_and_grads for unsliced parameters
        if len(unslice_params_fp16) == 0:
            global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_unslice_fp16 = layers.concat(unslice_params_fp16)
            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
            global_unslice_fp16 = paddle.cast(global_unslice_fp16,
                                              dtype=paddle.float32)

        # global norm of non-distributed FP32 params_and_grads
        global_norm_fp32 = layers.concat(
            sum_square_fp32) if len(sum_square_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

        # global norm of non-distributed FP32 params_and_grads for unsliced parameters
        global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
            unslice_params_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
        global_unslice_var = global_unslice_fp16 + global_unslice_fp32

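        # unsliced parameters are replicated on every rank, so their
        # contribution is scaled by 1/nranks before the sum all_reduce below
        # to avoid counting them nranks times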
        global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var

        # add all reduce to get global norm of distributed params_and_grads
        dev_id = int(self._device.split(":")[1])
        with device_guard(dev_id, "gpu"):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        global_norm_var = layers.sqrt(global_norm_var)
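        # clip scale = clip_norm / max(global_norm, clip_norm), so gradients
        # are only scaled down when the global norm exceeds clip_norm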
        max_global_norm = layers.fill_constant(shape=[1],
                                               dtype=global_norm_var.dtype,
                                               value=self.clip_norm)

        clip_var = layers.elementwise_div(x=max_global_norm,
                                          y=layers.elementwise_max(
                                              x=global_norm_var,
                                              y=max_global_norm))
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16)
            else:
                g.scale_(clip_var)
            g.stop_gradient = origin_state
            p._reset_grad_inplace_version(True)

        return params_grads

    def __getattr__(self, item):
        return getattr(self._clip, item)

    def __call__(self, params_grads):
        return self._dygraph_clip(params_grads)


@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
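    """
    Context manager that temporarily switches the current Paddle device to
    "cpu" or "gpu:<dev_id>" and restores the original device on exit.
    """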
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)


@dygraph_only
def ShardingScaler(scaler):
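    """
    Patch ``scaler._unscale`` so that unscaling works with a sharding
    optimizer: gradients are unscaled on CPU when the optimizer offloads its
    states, and the found-inf flag is all-reduced (MAX) across
    ``optimizer.group`` so every rank agrees on whether to skip the step.
    Returns the patched scaler.
    """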

    def unscale_method(self, optimizer):
        if not self._enable:
            return
        param_grads = []
        param_grads_fp16 = []
        param_grads_fp32 = []
        if hasattr(optimizer, "update_slice"):
            optimizer.update_slice()
            optimizer.update_scaler = True

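        # the wrapped optimizer may keep parameters either in _param_groups
        # (a list of dicts) or in a flat _parameter_list; collect the FP16 and
        # FP32 gradients from whichever layout is present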
        if getattr(optimizer._optim, '_param_groups', None) and isinstance(
                optimizer._optim._param_groups[0], dict):

            for group in optimizer._optim._param_groups:
                for param in group['params']:
                    if param._grad_ivar() is not None:
                        param_grads.append(param._grad_ivar())
                        if param._grad_ivar().dtype in [
                                core.VarDesc.VarType.FP16, paddle.float16
                        ]:
                            param_grads_fp16.append(param._grad_ivar())
                        else:
                            param_grads_fp32.append(param._grad_ivar())
        else:
            for param in optimizer._optim._parameter_list:
                if param.grad is not None:
                    param_grads.append(param.grad)
                    if param.grad.dtype in [
                            core.VarDesc.VarType.FP16, paddle.float16
                    ]:
                        param_grads_fp16.append(param.grad)
                    else:
                        param_grads_fp32.append(param.grad)

        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))

        device = "cpu" if optimizer.offload else "gpu"
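        # run check_finite_and_unscale on CPU when the optimizer offloads its
        # states, otherwise on the current GPU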
        dev_id = 0 if device == "cpu" else int(
            paddle.get_device().split(":")[1])

        with device_guard(dev_id, device):
            if len(param_grads_fp16):
                _legacy_C_ops.check_finite_and_unscale(param_grads_fp16,
                                                       self._scale,
                                                       param_grads_fp16,
                                                       temp_found_inf_fp16)
            if len(param_grads_fp32):
                _legacy_C_ops.check_finite_and_unscale(param_grads_fp32,
                                                       self._scale,
                                                       param_grads_fp32,
                                                       temp_found_inf_fp32)

        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

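        # synchronize the overflow flag across the sharding group so that
        # every rank skips the optimizer step if any rank saw inf/nan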
        paddle.distributed.all_reduce(is_found_inf,
                                      op=paddle.distributed.ReduceOp.MAX,
                                      group=optimizer.group)
        self._found_inf = is_found_inf.numpy()[0]

    scaler._unscale = MethodType(unscale_method, scaler)
    return scaler