#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import framework

# (TODO: GhostScreaming) It will be removed later.
from paddle.fluid import core
from paddle.framework import (
    _in_legacy_dygraph,
    _split_tensors,
    build_groups,
    in_dygraph_mode,
    sync_params_buffers,
)

from .log_util import logger

__all__ = []

def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None):
    """Fuse, scale and all-reduce dense gradients (legacy-dygraph path).

    Trainable parameters' gradients are coalesced into buckets of at most
    ``bucket_size`` bytes, optionally rescaled, all-reduced across
    ``comm_group`` and then split back into the original gradient tensors.

    Args:
        parameters (list): parameters whose gradients are synchronized;
            non-trainable parameters and parameters without a gradient
            are skipped.
        comm_group: communication group; ``None`` means the global world.
        bucket_size (int): maximum number of bytes fused per all-reduce.
        scale (float|None): ``None`` divides gradients by the number of
            ranks (i.e. averages them); otherwise gradients are divided by
            ``1.0 / scale``, i.e. multiplied by ``scale``.
    """
    grad_var_set = set()
    grad_vars = []

    for param in parameters:
        if param.trainable and (param._grad_ivar() is not None):
            g_var = param._grad_ivar()
            assert (
                not g_var._is_sparse()
            ), "Now, it doesn't support sparse parameters"
            grad_vars.append(g_var)
            # every gradient must appear only once in the fused groups
            assert g_var not in grad_var_set
            grad_var_set.add(g_var)

    coalesced_grads_and_vars = build_groups(grad_vars, bucket_size)

    nranks = (
        paddle.distributed.get_world_size()
        if comm_group is None
        else comm_group.nranks
    )

    # Normalize the divisor: None -> divide by nranks; explicit scale ->
    # divide by 1/scale. A resulting factor of 1.0 needs no division at all.
    scale = nranks if scale is None else 1.0 / scale
    scale = None if scale == 1.0 else scale

    for coalesced_grad, _, _ in coalesced_grads_and_vars:
        # need to div nranks
        if scale is not None:
            div_factor = paddle.to_tensor(scale, dtype=coalesced_grad.dtype)
            paddle.fluid.framework._dygraph_tracer().trace_op(
                type="elementwise_div",
                inputs={'X': coalesced_grad, 'Y': div_factor},
                outputs={'Out': coalesced_grad},
                attrs={'axis': -1},
            )
        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

    # restore the fused results back into each parameter's own grad tensor
    _split_tensors(coalesced_grads_and_vars)


def _apply_collective_grads_eager(
    parameters, comm_group, bucket_size, scale=None
):
    """Fuse, scale and all-reduce dense gradients (eager-mode path).

    Gradients of trainable parameters are bucketed into fused tensors of
    at most ``bucket_size`` bytes, scaled in place, all-reduced over
    ``comm_group`` and split back into the per-parameter gradients.
    ``scale=None`` averages over the group; otherwise gradients are
    multiplied by ``scale``.
    """
    seen_grads = set()
    dense_grads = []

    for param in parameters:
        if not param.trainable or param._grad_ivar() is None:
            continue
        grad = param._grad_ivar()
        assert (
            not grad.is_sparse()
        ), "Now, it doesn't support sparse parameters"
        dense_grads.append(grad)
        # each gradient may only be fused once
        assert grad not in seen_grads
        seen_grads.add(grad)

    fused_groups = build_groups(dense_grads, bucket_size)

    if comm_group is None:
        nranks = paddle.distributed.get_world_size()
    else:
        nranks = comm_group.nranks

    # default to averaging; skip scaling entirely when the factor is 1.0
    if scale is None:
        scale = 1.0 / nranks
    if scale == 1.0:
        scale = None

    for fused_grad, _, _ in fused_groups:
        # need to div nranks
        if scale is not None:
            fused_grad.scale_(scale)
        paddle.distributed.all_reduce(fused_grad, group=comm_group)

    _split_tensors(fused_groups)


def _broadcast_data_help(data, shape, dtype, hcg):
    """Broadcast ``data`` from the model-parallel source rank to all other
    ranks of the model-parallel group, in place.

    The shape is broadcast first so non-source ranks can allocate a
    receive buffer of the right size; the received tensor is then shared
    back into ``data`` so the caller's handle stays valid.

    Args:
        data: tensor to send (on the source rank) or overwrite (elsewhere).
        shape: shape of ``data`` on the source rank.
        dtype: dtype used to allocate the receive buffer.
        hcg: hybrid communicate group providing the model-parallel group,
            its source rank and this rank's model-parallel rank.
    """
    model_parallel_group = hcg.get_model_parallel_group()
    src_rank = hcg.get_model_parallel_group_src_rank()
    mp_rank = hcg.get_model_parallel_rank()

    # step 1: broadcast the shape so receivers can allocate a buffer
    shape_gpu = paddle.to_tensor(shape, dtype="int32")
    paddle.distributed.broadcast(
        shape_gpu, src=src_rank, group=model_parallel_group, sync_op=True
    )

    # NOTE(review): assumes the source rank always has mp_rank == 0 — the
    # src_rank/mp_rank checks would diverge otherwise; confirm against hcg.
    if mp_rank != 0:
        input_data = paddle.zeros(shape_gpu, dtype=dtype)
    else:
        input_data = data

    # step 2: broadcast the payload itself
    paddle.distributed.broadcast(
        input_data, src=src_rank, group=model_parallel_group, sync_op=True
    )

    if mp_rank != 0:
        # make the received buffer visible through the caller's handle
        # without copying
        if in_dygraph_mode():
            data._clear_data()
            input_data._share_buffer_to(data)
        else:
            data.value().get_tensor()._clear()
            data.value().get_tensor()._share_data_with(
                input_data.value().get_tensor()
            )


def _move_to_device_and_broadcast(v, cur_device, hcg):
    # Move ``v`` onto the current GPU if it is not there yet (sharing the
    # moved buffer back into the caller's handle), then broadcast it over
    # the model-parallel group in place.
    with framework.no_grad():
        if (
            "gpu" in cur_device
            and in_dygraph_mode()
            and not v.place.is_gpu_place()
        ):
            v_gpu = v.cuda(int(cur_device.split(":")[1]))
            v._clear_data()
            v_gpu._share_buffer_to(v)
        _broadcast_data_help(v, v.shape, v.dtype, hcg)


def broadcast_input_data(hcg, *inputs, **kwargs):
    """Broadcast every tensor in ``inputs`` and ``kwargs`` across the
    model-parallel group, in place.

    Non-tensor values are left untouched and an error is logged for them.

    Args:
        hcg: hybrid communicate group providing the model-parallel group.
        *inputs: positional tensors to synchronize.
        **kwargs: keyword tensors to synchronize.

    Returns:
        tuple: the (mutated in place) ``inputs`` and ``kwargs``.
    """
    cur_device = paddle.get_device()
    for v in inputs:
        if isinstance(v, (core.VarBase, core.eager.Tensor)):
            _move_to_device_and_broadcast(v, cur_device, hcg)
        else:
            logger.error("it doesn't support data type {}".format(type(v)))

    for k, v in kwargs.items():
        if isinstance(v, (core.VarBase, core.eager.Tensor)):
            _move_to_device_and_broadcast(v, cur_device, hcg)
            kwargs[k] = v
        else:
            logger.error("it doesn't support data type {}".format(type(v)))
    return inputs, kwargs


def broadcast_mp_parameters(model, hcg):
    """Sync ``model`` parameters and buffers from the model-parallel
    source rank to every rank in the model-parallel group."""
    group = hcg.get_model_parallel_group()
    sync_params_buffers(
        model,
        group,
        hcg.get_model_parallel_group_src_rank(),
        is_model_parallel=True,
    )


def broadcast_dp_parameters(model, hcg):
    """Sync ``model`` parameters and buffers from the data-parallel
    source rank to every rank in the data-parallel group."""
    group = hcg.get_data_parallel_group()
    sync_params_buffers(
        model,
        group,
        hcg.get_data_parallel_group_src_rank(),
        is_model_parallel=False,
    )


def fused_allreduce_gradients_with_group(
    parameter_list, group, bucket_size=128 * 1024 * 1024, scale=None
):
    """Fuse and all-reduce the gradients of ``parameter_list`` over
    ``group``, dispatching to the eager or legacy implementation that
    matches the current dygraph mode."""
    if in_dygraph_mode():
        apply_func = _apply_collective_grads_eager
    else:
        apply_func = _apply_collective_grads
    with framework.no_grad():
        apply_func(parameter_list, group, bucket_size, scale)


def fused_allreduce_gradients(parameter_list, hcg):
    """All-reduce ``parameter_list`` gradients across the data-parallel
    group (or the global world when ``hcg`` is ``None``)."""
    group = hcg.get_data_parallel_group() if hcg is not None else None
    logger.debug("dp start fuse allreduce gradients")
    fused_allreduce_gradients_with_group(parameter_list, group)


def sharding_reduce_gradients(parameter_list, hcg):
    """Average the gradients of ``parameter_list`` across the sharding
    parallel group, in place.

    Every trainable parameter's gradient is all-reduced over the sharding
    group and divided by the group size, so each rank ends up holding the
    mean gradient.

    Args:
        parameter_list (list): parameters whose gradients are reduced;
            non-trainable parameters and parameters without a gradient
            are skipped.
        hcg: hybrid communicate group providing the sharding group.
    """
    # TODO allreduce --> reduce
    # TODO merge grad / nrank with dp
    logger.debug("sharding start gradients sync")
    with framework.no_grad():

        sharding_nrank = hcg.get_sharding_parallel_group().nranks
        for param in parameter_list:
            if param.trainable and (param._grad_ivar() is not None):
                if in_dygraph_mode():
                    # eager path: scale before the all-reduce;
                    # sum(g_i / n) over ranks gives the mean
                    param.grad.scale_(1.0 / sharding_nrank)
                    paddle.distributed.all_reduce(
                        param.grad,
                        group=hcg.get_sharding_parallel_group(),
                        sync_op=True,
                    )

                elif _in_legacy_dygraph():
                    g_var = param._grad_ivar()
                    # need use trace_op to allreduce
                    # paddle.distributed.all_reduce(
                    #     g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
                    paddle.fluid.framework._dygraph_tracer().trace_op(
                        type="c_allreduce_sum",
                        inputs={'X': g_var},
                        outputs={'Out': g_var},
                        attrs={
                            'ring_id': hcg.get_sharding_parallel_group().id,
                            'use_calc_stream': True,
                        },
                    )

                    # grad / sharding_rank
                    # legacy path divides after the all-reduce instead;
                    # the result is the same mean gradient
                    div_factor = paddle.to_tensor(
                        sharding_nrank, dtype=g_var.dtype
                    )
                    paddle.fluid.framework._dygraph_tracer().trace_op(
                        type="elementwise_div",
                        inputs={'X': g_var, 'Y': div_factor},
                        outputs={'Out': g_var},
                        attrs={'axis': -1},
                    )


def broadcast_sharding_parameters(model, hcg):
    """Sync ``model`` parameters and buffers from the sharding-parallel
    source rank to every rank in the sharding group."""
    # TODO TO save memory, use un-fused broadcast to avoid potentional OOM
    logger.debug("sharding start init parameters sync")
    group = hcg.get_sharding_parallel_group()
    sync_params_buffers(
        model,
        group,
        hcg.get_sharding_parallel_group_src_rank(),
        is_model_parallel=False,
    )