#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import contextlib
from collections import abc
from enum import Enum
from math import inf
import numpy as np
from types import MethodType

import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph import base as imperative_base
from paddle.distributed.collective import _get_global_group


class Taskflow:
    """
    Task flow: a one-way linked list node used for task acquisition.
    """

    def __init__(self, task, callback):
        self.task = task
        self.callback = callback


class Type(Enum):
    """
    Type of trainable parameters
    """
    fp16 = paddle.float16
    bf16 = paddle.bfloat16
    fp32 = paddle.float32


class ShardingClipGrad:
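    """
    Wrap a gradient-clipping strategy (e.g. ``paddle.nn.ClipGradByGlobalNorm``)
    for sharding: the squared gradient norms are all-reduced across ``group``
    so that every rank clips its local gradients against the same global norm.
    """
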
    def __init__(self, clip, device, group):
        self._clip = clip
        self._device = device
        self._group = group

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        sum_square_fp32, sum_square_fp16 = [], []
        unslice_params_fp32, unslice_params_fp16 = [], []

        for p, g in params_grads:
            p_slice = True  # whether this parameter is sliced in sharding stage3
            if g is None or getattr(p, 'need_clip', True) is False:
                continue
            if hasattr(p, "unslice"):
                p_slice = False

            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.get_tensor_from_selected_rows(
                    layers.merge_selected_rows(g))
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)

            if p.dtype == paddle.float16:
                if p_slice:
                    sum_square_fp16.append(sum_square)
                else:
                    unslice_params_fp16.append(sum_square)
            elif p.dtype == paddle.float32:
                if p_slice:
                    sum_square_fp32.append(sum_square)
                else:
                    unslice_params_fp32.append(sum_square)

        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_fp16) == 0:
            global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_norm_fp16 = layers.concat(sum_square_fp16)
            global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
            global_norm_fp16 = paddle.cast(global_norm_fp16,
                                           dtype=paddle.float32)

        # global norm of non-distributed FP16 params_and_grads for unsliced parameters
        if len(unslice_params_fp16) == 0:
            global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_unslice_fp16 = layers.concat(unslice_params_fp16)
            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
            global_unslice_fp16 = paddle.cast(global_unslice_fp16,
                                              dtype=paddle.float32)

        # global norm of non-distributed FP32 params_and_grads
        global_norm_fp32 = layers.concat(
            sum_square_fp32) if len(sum_square_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

        # global norm of non-distributed FP32 params_and_grads for unsliced parameters
        global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
            unslice_params_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
        global_unslice_var = global_unslice_fp16 + global_unslice_fp32

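        # Unsliced parameters are replicated on every rank of the sharding
        # group, so scale their contribution by 1 / nranks to avoid counting
        # them nranks times in the all_reduce below.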
        global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var

        # add all reduce to get global norm of distributed params_and_grads
        dev_id = int(self._device.split(":")[1])
        with device_guard(dev_id, "gpu"):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(shape=[1],
                                               dtype=global_norm_var.dtype,
                                               value=self.clip_norm)

        clip_var = layers.elementwise_div(x=max_global_norm,
                                          y=layers.elementwise_max(
                                              x=global_norm_var,
                                              y=max_global_norm))
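        # Rescale each gradient in place by the clip coefficient; FP16
        # gradients use a casted copy of the coefficient.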
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16)
            else:
                g.scale_(clip_var)
            g.stop_gradient = origin_state
            p._reset_grad_inplace_version(True)

        return params_grads

    def __getattr__(self, item):
        return getattr(self._clip, item)

    def __call__(self, params_grads):
        return self._dygraph_clip(params_grads)


@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
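    """
    Context manager that temporarily switches the active Paddle device to
    ``cpu`` or ``gpu:<dev_id>``, restoring the original device on exit.
    """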
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)


@dygraph_only
def ShardingScaler(scaler):
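    """
    Patch ``scaler._unscale`` with a sharding-aware unscale method: gradients
    are unscaled and checked for inf/nan locally, then the found-inf flag is
    all-reduced (MAX) across ``optimizer.group`` so all ranks agree on whether
    to skip the step.

    A minimal sketch, assuming a sharding optimizer that exposes the
    ``offload`` and ``group`` attributes used below:

        scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=2.**16))
    """
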
    def unscale_method(self, optimizer):
        if not self._enable:
            return
        param_grads = []
        param_grads_fp16 = []
        param_grads_fp32 = []
        if hasattr(optimizer, "update_slice"):
            optimizer.update_slice()
            optimizer.update_scaler = True

        if getattr(optimizer._optim, '_param_groups', None) and isinstance(
                optimizer._optim._param_groups[0], dict):

            for group in optimizer._optim._param_groups:
                for param in group['params']:
                    if param._grad_ivar() is not None:
                        param_grads.append(param._grad_ivar())
                        if param._grad_ivar().dtype in [
                                core.VarDesc.VarType.FP16, paddle.float16
                        ]:
                            param_grads_fp16.append(param._grad_ivar())
                        else:
                            param_grads_fp32.append(param._grad_ivar())
        else:
            for param in optimizer._optim._parameter_list:
                if param.grad is not None:
                    param_grads.append(param.grad)
                    if param.grad.dtype in [
                            core.VarDesc.VarType.FP16, paddle.float16
                    ]:
                        param_grads_fp16.append(param.grad)
                    else:
                        param_grads_fp32.append(param.grad)

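        # Per-dtype found-inf flags; check_finite_and_unscale fills them in
        # place while unscaling the corresponding gradients.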
        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))

        device = "cpu" if optimizer.offload else "gpu"
        dev_id = 0 if device == "cpu" else int(
            paddle.get_device().split(":")[1])

        with device_guard(dev_id, device):
            if len(param_grads_fp16):
                _legacy_C_ops.check_finite_and_unscale(param_grads_fp16,
                                                       self._scale,
                                                       param_grads_fp16,
                                                       temp_found_inf_fp16)
            if len(param_grads_fp32):
                _legacy_C_ops.check_finite_and_unscale(param_grads_fp32,
                                                       self._scale,
                                                       param_grads_fp32,
                                                       temp_found_inf_fp32)

        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

        paddle.distributed.all_reduce(is_found_inf,
                                      op=paddle.distributed.ReduceOp.MAX,
                                      group=optimizer.group)
        self._found_inf = is_found_inf.numpy()[0]

    scaler._unscale = MethodType(unscale_method, scaler)
    return scaler
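
# A minimal usage sketch (hypothetical wiring; `sharding_group` and the place
# where the wrapped clip is attached are assumptions, not part of this module):
#
#   clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
#   grad_clip = ShardingClipGrad(clip, paddle.get_device(), sharding_group)
#   scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=2.**16))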