# sharding_utils.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import contextlib
from collections import abc
from enum import Enum
from math import inf
20 21
import numpy as np
from types import MethodType
22 23

import paddle
24
from paddle import _C_ops
25
from paddle.fluid import core
26 27 28 29
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph import base as imperative_base
30
from paddle.distributed.collective import _get_global_group
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50


class Taskflow:
    """
    Holds one task together with the callback that belongs to it.

    Described by the original author as a one-way linked list used for
    task acquisition; this class itself only stores the two fields.
    """

    def __init__(self, task, callback):
        # Both values are stored as given, without validation or copying.
        self.task, self.callback = task, callback


class Type(Enum):
    """
    Type of trainable parameters

    Each member's value is the matching Paddle dtype object.
    """
    # Half-precision (float16) parameters.
    fp16 = paddle.float16
    # Single-precision (float32) parameters.
    fp32 = paddle.float32


51
class ShardingClipGrad:
    """
    Wrap a gradient-clip object and clip gradients by a global norm whose
    sliced part is summed across ``group`` with an all-reduce.

    Attributes that are not found on this wrapper (for example
    ``clip_norm``) are looked up on the wrapped ``clip`` object.

    Args:
        clip: The wrapped clip object; must expose ``clip_norm``.
        device: Device string of the form ``"<kind>:<index>"``; the index
            after the colon selects the GPU used for the all-reduce.
        group: Communication group passed to
            ``paddle.distributed.all_reduce``.
    """

    def __init__(self, clip, device, group):
        self._clip = clip
        self._device = device
        self._group = group

    @imperative_base.no_grad
    def _dygraph_clip(self, params_grads):
        """
        Scale every clippable gradient in place by
        ``clip_norm / max(global_norm, clip_norm)``.

        Args:
            params_grads: Iterable of ``(param, grad)`` pairs. It is
                traversed twice, so it must be re-iterable (e.g. a list).

        Returns:
            The same ``params_grads`` object; gradients are modified in
            place.
        """
        sum_square_fp32, sum_square_fp16 = [], []
        unslice_params_fp32, unslice_params_fp16 = [], []

        for p, g in params_grads:
            if g is None or getattr(p, 'need_clip', True) is False:
                continue
            # Parameters carrying an ``unslice`` attribute are accumulated
            # separately and are NOT included in the all-reduce below.
            p_slice = not hasattr(p, "unslice")

            merge_grad = g
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.get_tensor_from_selected_rows(
                    layers.merge_selected_rows(g))
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)

            # Gradients whose parameter dtype is neither fp16 nor fp32 do
            # not contribute to the norm.
            if p.dtype == paddle.float16:
                if p_slice:
                    sum_square_fp16.append(sum_square)
                else:
                    unslice_params_fp16.append(sum_square)
            elif p.dtype == paddle.float32:
                if p_slice:
                    sum_square_fp32.append(sum_square)
                else:
                    unslice_params_fp32.append(sum_square)

        # Squared norm of the sliced FP16 gradients, cast to FP32.
        if len(sum_square_fp16) == 0:
            global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_norm_fp16 = layers.concat(sum_square_fp16)
            global_norm_fp16 = layers.reduce_sum(global_norm_fp16)
            global_norm_fp16 = paddle.cast(
                global_norm_fp16, dtype=paddle.float32)

        # Squared norm of the unsliced FP16 gradients, cast to FP32.
        if len(unslice_params_fp16) == 0:
            global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
        else:
            global_unslice_fp16 = layers.concat(unslice_params_fp16)
            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
            global_unslice_fp16 = paddle.cast(
                global_unslice_fp16, dtype=paddle.float32)

        # Squared norm of the sliced FP32 gradients.
        global_norm_fp32 = layers.concat(sum_square_fp32) if len(
            sum_square_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

        # Squared norm of the unsliced FP32 gradients.
        global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
            unslice_params_fp32) != 0 else paddle.to_tensor(
                [0.], dtype=paddle.float32)
        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
        global_unslice_var = global_unslice_fp16 + global_unslice_fp32

        global_norm_var = global_norm_fp16 + global_norm_fp32

        # Sum the sliced contribution over the group. The device string is
        # expected to contain ":<index>"; a bare "cpu"/"gpu" would raise
        # IndexError here.
        dev_id = int(self._device.split(":")[1])
        with device_guard(dev_id, "gpu"):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        # The unsliced contribution is added after the all-reduce, so it is
        # counted once rather than once per rank.
        global_norm_var += global_unslice_var
        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)

        # clip_var == 1 when the norm is within bounds, < 1 otherwise.
        clip_var = layers.elementwise_div(
            x=max_global_norm,
            y=layers.elementwise_max(
                x=global_norm_var, y=max_global_norm))
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

        for p, g in params_grads:
            if getattr(p, 'need_clip', True) is False or g is None:
                continue
            # Temporarily mark the gradient as stop_gradient around the
            # in-place scale, then restore the caller's setting.
            origin_state = g.stop_gradient
            g.stop_gradient = True
            if p.dtype == paddle.float16:
                g.scale_(clip_var_fp16)
            else:
                g.scale_(clip_var)
            g.stop_gradient = origin_state
            p._reset_grad_inplace_version(True)

        return params_grads

    def __getattr__(self, item):
        """
        Fall back to the wrapped clip object for unknown attributes.

        Raises:
            AttributeError: If ``_clip`` itself is not set yet, or if the
                wrapped object lacks ``item``.
        """
        # __getattr__ only runs after normal lookup failed. If `_clip` is the
        # missing name (instance built via __new__, e.g. by copy/pickle),
        # reading self._clip would re-enter this method without end.
        if item == "_clip":
            raise AttributeError("'{}' object has no attribute '{}'".format(
                type(self).__name__, item))
        return getattr(self._clip, item)

    def __call__(self, params_grads):
        """Clip ``params_grads``; see ``_dygraph_clip``."""
        return self._dygraph_clip(params_grads)


153
@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    """
    Context manager that switches Paddle's current device for the duration
    of the ``with`` body and restores the previous device afterwards.

    Args:
        dev_id: GPU index; only used when ``device == "gpu"``.
        device: ``"cpu"`` or ``"gpu"``. Any other value leaves the current
            device unchanged (the previous device is still re-applied on
            exit).
    """
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        # Restore even when the body raised.
        paddle.set_device(origin_device)
164 165 166


@dygraph_only
def ShardingScaler(scaler):
    """
    Replace ``scaler._unscale`` with a sharding-aware implementation.

    The patched method unscales the optimizer's gradients, records whether
    any non-finite value was found, and takes the MAX of that flag across
    ``optimizer.group`` so every rank ends up with the same ``_found_inf``.

    Args:
        scaler: The loss scaler to patch; it is modified in place.

    Returns:
        The same ``scaler`` object.
    """

    def unscale_method(self, optimizer):
        if not self._enable:
            return

        param_grads_fp16 = []
        param_grads_fp32 = []
        fp16_dtypes = (core.VarDesc.VarType.FP16, paddle.float16)

        def _bucket(grad):
            # Everything that is not fp16 is handled on the fp32 path.
            if grad.dtype in fp16_dtypes:
                param_grads_fp16.append(grad)
            else:
                param_grads_fp32.append(grad)

        if hasattr(optimizer, "update_slice"):
            optimizer.update_slice()
            optimizer.update_scaler = True

        inner_optim = optimizer._optim
        if getattr(inner_optim, '_param_groups', None) and isinstance(
                inner_optim._param_groups[0], dict):
            for group in inner_optim._param_groups:
                for param in group['params']:
                    grad = param._grad_ivar()
                    if grad is not None:
                        _bucket(grad)
        else:
            for param in inner_optim._parameter_list:
                grad = param.grad
                if grad is not None:
                    _bucket(grad)

        # np.bool_ instead of the np.bool alias, which was removed from NumPy.
        temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
        temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))

        device = "cpu" if optimizer.offload else "gpu"
        dev_id = 0 if device == "cpu" else int(paddle.get_device().split(":")[
            1])

        with device_guard(dev_id, device):
            # The gradient list is passed as both the input and the output
            # argument, and the found-inf tensor as the last argument.
            if len(param_grads_fp16):
                _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
                                                param_grads_fp16,
                                                temp_found_inf_fp16)
            if len(param_grads_fp32):
                _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                                param_grads_fp32,
                                                temp_found_inf_fp32)

        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

        # MAX over the group: one rank seeing inf/nan flags all ranks.
        paddle.distributed.all_reduce(
            is_found_inf,
            op=paddle.distributed.ReduceOp.MAX,
            group=optimizer.group)
        self._found_inf = is_found_inf.numpy()[0]

    scaler._unscale = MethodType(unscale_method, scaler)
    return scaler