# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
from collections import OrderedDict

import numpy as np

import paddle
from paddle.framework import base as imperative_base
from paddle.framework import core


class HOOK_ACTION:
    ALL_REDUCE = 0
    REDUCE = 1


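# Buffer alignment in bytes per device, and element size in bytes per dtype;
# both are used to pad the fused buffers so every parameter starts at an
# aligned offset.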
alignment = {
    "gpu": 256,
}

align = {
    paddle.float16.value: 2,
    paddle.bfloat16.value: 2,
    paddle.float32.value: 4,
}


def assign_group_by_size(parameters, group_size=128 * 1024 * 1024):
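    """Group ``parameters`` into buckets of roughly ``group_size`` bytes each."""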
    is_sparse_gradient = [False] * len(parameters)

    group_indices = core.eager_assign_group_by_size(
        parameters, is_sparse_gradient, [group_size, group_size]
    )

    var_groups = OrderedDict()
    for group_idx, indices in enumerate(group_indices):
        for index in indices:
            var_groups.setdefault(group_idx, []).append(parameters[index])
    return var_groups


def flatten_dense_tensors(
    parameters, use_main_grad=False, fuse_param=True, warp_buffer=False
):
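    """Pack the gradients of ``parameters`` (and, if ``fuse_param`` is True, the
    parameters themselves) into contiguous aligned storages.

    Returns ``(param_storage, grad_storage)`` if ``fuse_param`` is True,
    otherwise just ``grad_storage``.
    """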
    from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import (
        GradStorage,
        ParamStorage,
    )

    _buffer_size = 0
    _param2align = {}
    dtype = parameters[0].dtype

    for param in parameters:
        assert param.trainable, "param must be trainable..."
        # Pad each param so the next one starts on an alignment["gpu"]-byte
        # boundary; align_ is that padding expressed in elements of `dtype`.
        size = np.prod(param.shape) * align[dtype]
        remaining = size % alignment["gpu"]
        ali = 0 if remaining == 0 else alignment["gpu"] - remaining
        align_ = ali // align[dtype]
        _buffer_size += np.prod(param.shape) + align_
        _param2align[param.name] = align_

    if fuse_param:
        param_storage = ParamStorage(
            size=_buffer_size, dtype=dtype, device="gpu"
        )
        param_storage.add_rank_params(parameters, _param2align)

    # process gradient
    grad_dtype = paddle.float32 if use_main_grad else dtype
    grad_storage = GradStorage(
        size=_buffer_size,
        dtype=grad_dtype,
        device="gpu",
        destination="0",
        parm2align=_param2align,
    )

    for param in parameters:
        grad_storage.add_grad(param, _param2align[param.name])

    if warp_buffer:
        if fuse_param:
            param_storage.warp_buffer()
        grad_storage.warp_buffer()

    if fuse_param:
        if not use_main_grad:
            # param_storage --> grad_storage
            param_storage.buffer._copy_gradient_from(grad_storage.buffer)
        else:
            param_storage.buffer.main_grad = grad_storage.buffer
        param_storage.buffer.stop_gradient = False
        return param_storage, grad_storage
    else:
        return grad_storage


def bw_hook_func(buffer, param):
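    """Return a backward hook that checks ``param``'s gradient into ``buffer``."""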
    @paddle.autograd.no_grad()
    def fused_comm(*_):
        buffer.add_grad(param)

    return fused_comm


class FusedCommBuffer:
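    """Accumulates the gradients of one parameter group into a fused storage and
    launches a single asynchronous all-reduce/reduce once every parameter has
    checked in ``acc_steps`` gradients."""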
    def __init__(
        self,
        id,
        params,
        comm_group,
        acc_steps=1,
        act=None,
        dst=-1,
        use_main_grad=None,
        fuse_param=False,
        scale_after_comm=True,
    ):
        self._id = id
        self._params = params
        self._acc_steps = acc_steps
        self._comm_group = comm_group
        self._scale_after_comm = scale_after_comm
        self._fuse_param = fuse_param

        self.use_main_grad = (
            use_main_grad
            if use_main_grad is not None
            else hasattr(self._params[0], "main_grad")
        )

        self._task = None
        self._params_step_dict = {}
        self._params_checked_in = 0
        self._grads_to_addr = {}

        self._act = act
        if self._act == HOOK_ACTION.ALL_REDUCE:
            assert dst == -1
        elif self._act == HOOK_ACTION.REDUCE:
            assert dst != -1
        else:
            raise ValueError(
                "The act should be allreudce for dp or reduce for sharding."
            )
        self._dst = dst

        self._init_step_dict()

        if self._fuse_param:
            self.param_storage, self.grad_storage = flatten_dense_tensors(
                self._params,
                use_main_grad=self.use_main_grad,
                fuse_param=True,
                warp_buffer=True,
            )
            self.param_storage = self.param_storage.buffer
            self.grad_storage = self.grad_storage.buffer
        else:
            self.param_storage = None
            self.grad_storage = flatten_dense_tensors(
                self._params,
                use_main_grad=self.use_main_grad,
                fuse_param=False,
                warp_buffer=False,
            ).buffer

        self._record_addr()

    def _record_addr(self):
        for param in self._params:
            addr = (
                param.main_grad.data_ptr()
                if self.use_main_grad
                else param.grad.data_ptr()
            )
            self._grads_to_addr[param.name] = addr

    def _init_step_dict(self):
        for p in self._params:
            self._params_step_dict[p.name] = 0

    def _reset_params_checked_in(self):
        self._task = None
        self._init_step_dict()
        self._params_checked_in = 0

    @property
    def _all_params_checked_in(self):
        return (
            len(self._params) == self._params_checked_in
            and len(self._params_step_dict) == 0
        )

    def add_grad(self, param):
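        """Check in one gradient; trigger the fused communication once all
        params in this buffer have accumulated ``acc_steps`` gradients."""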
        assert param.name in self._params_step_dict
        current_ptr = (
            param.main_grad.data_ptr()
            if self.use_main_grad
            else param.grad.data_ptr()
        )
        if self._grads_to_addr[param.name] != current_ptr:
            raise ValueError(
                "The address of the grad/main_grad of the param has been changed during training, "
                "which is not allowed for dp/sharding overlap with pp. "
                "This may be caused by some non-inplace operations on the grad/main_grad. "
                "Please use the inplace version of the operations or disable the overlapping."
            )

        self._params_step_dict[param.name] += 1

        if self._params_step_dict[param.name] == self._acc_steps:
            self._params_checked_in += 1
            self._params_step_dict.pop(param.name)

        if self._all_params_checked_in:
            self._comm_grads()

    @imperative_base.no_grad
    def _comm_grads(self):
        assert self._all_params_checked_in

        if not self._scale_after_comm:
            scale_factor = 1.0 / self._comm_group.nranks
            self.grad_storage.scale_(scale_factor)

        if self._act == HOOK_ACTION.ALL_REDUCE:
            task = paddle.distributed.all_reduce(
                self.grad_storage, group=self._comm_group, sync_op=False
            )

        elif self._act == HOOK_ACTION.REDUCE:
            task = paddle.distributed.reduce(
                self.grad_storage,
                dst=self._dst,
                group=self._comm_group,
                sync_op=False,
            )

        self._task = task

    @imperative_base.no_grad
    def scale_grads(self):
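        """Wait for the pending communication and, if scaling was deferred,
        apply the 1/nranks scaling to the fused gradient storage."""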
        assert self._task is not None
        self._task.wait()

        if self._scale_after_comm:
            scale_factor = 1.0 / self._comm_group.nranks
            self.grad_storage.scale_(scale_factor)

        self._reset_params_checked_in()


def obtain_storage(
    parameters,
    use_main_grad=False,
    clip=True,
    dist=False,
    fuse_param=True,
    comm_overlap=False,
    act=None,
    comm_group=None,
    dst=-1,
    acc_steps=1,
    scale_after_comm=False,
):
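    """Bucket ``parameters`` by size and build one ``FusedCommBuffer`` per bucket.

    Returns the fused param storages (when ``fuse_param`` is True) and the comm
    buffers (when ``comm_overlap`` is True).
    """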
    if len(parameters) < 1:
        return [], []

    var_groups = assign_group_by_size(parameters, group_size=256 * 1024 * 1024)
    storage = []
    buffers = []
    for group_idx, parameters in var_groups.items():
        comm_buffer = FusedCommBuffer(
            group_idx,
            parameters,
            comm_group=comm_group,
            acc_steps=acc_steps,
            act=act,
            dst=dst,
            use_main_grad=use_main_grad,
            fuse_param=fuse_param,
            scale_after_comm=scale_after_comm,
        )
        if fuse_param:
            param_buffer = comm_buffer.param_storage
            param_buffer.need_clip = clip
            param_buffer.is_distributed = dist
            storage.append(param_buffer)
        if comm_overlap:
            for param in parameters:
                param._register_backward_hook(bw_hook_func(comm_buffer, param))
            buffers.append(comm_buffer)

    return storage, buffers


def filter_params(params, is_fp32, is_distributed, need_clip):
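    """Select the params matching the given fp32/distributed/need_clip flags and
    return them together with their (common) dtype."""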
    params = list(
        filter(
            lambda x: x.is_distributed
            if is_distributed
            else (not x.is_distributed),
            params,
        )
    )
    params = list(
        filter(
            lambda x: getattr(x, 'need_clip', True)
            if need_clip
            else (not getattr(x, 'need_clip', True)),
            params,
        )
    )
    params = list(
        filter(
            lambda x: x.dtype == paddle.float32
            if is_fp32
            else x.dtype != paddle.float32,
            params,
        )
    )
    dtype = None
    for p in params:
        if dtype is None:
            dtype = p.dtype
        else:
            assert dtype == p.dtype

    return params, dtype


def fused_parameters(
    parameters,
    use_main_grad=False,
    fuse_param=True,
    comm_overlap=False,
    comm_group=None,
    dst=-1,
    acc_step=1,
    scale_after_comm=False,
):
    """
    Fuse gradients, fuse parameters if enabled, and prepare for comm overlap if enabled.
    :param parameters: all parameters to be fused.
    :param use_main_grad: whether the gradients use main_grad (fp32 accumulation) or not
    :param comm_overlap: whether to enable comm overlap
    :param comm_group: the comm group used for comm overlap
    :param dst: the destination rank for comm overlap
    :param acc_step: accumulation steps, used for comm overlap
    :param fuse_param: whether to fuse the parameters as well
    :param scale_after_comm: if comm overlap is enabled, whether gradients are scaled after (rather than before) communication
    :return: param storages if fused, comm buffers if comm overlap is enabled
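
    Example (a minimal sketch; assumes ``model`` and an initialized
    data-parallel communication group ``dp_group`` exist)::

        decay_fused, all_fused, all_buffers = fused_parameters(
            model.parameters(),
            comm_overlap=True,
            comm_group=dp_group,
        )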
    """
    g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
    act = (
        HOOK_ACTION.ALL_REDUCE if not g_shard_use_reduce else HOOK_ACTION.REDUCE
    )
    if comm_overlap:
        assert comm_group is not None
    if act == HOOK_ACTION.REDUCE:
        assert dst != -1
    elif act == HOOK_ACTION.ALL_REDUCE:
        dst = -1
    param_groups = []
    attrs = []

    is_fp32 = [True, False]
    is_distributed = [True, False]
    need_clip = [True, False]

    no_fp32_dtype = None
    for fp32, dist, clip in itertools.product(
        is_fp32, is_distributed, need_clip
    ):
        params, dtype = filter_params(parameters, fp32, dist, clip)
        if not fp32:
            if no_fp32_dtype is None:
                no_fp32_dtype = dtype
            elif dtype is not None:
                assert no_fp32_dtype == dtype
        attrs.append([dtype, dist, clip])
        param_groups.append(params)

    decay_fused = []
    all_fused = []
    all_buffers = []
    for params, attr in zip(param_groups, attrs):
        decay_params = []
        other_params = []

        for param in params:
            if not any(nd in param.name for nd in ["bias", "norm", "b_0"]):
                decay_params.append(param)
            else:
                other_params.append(param)

        is_distributed = attr[1]
        need_clip = attr[2]
        decay, decay_buffers = obtain_storage(
            decay_params,
            use_main_grad=use_main_grad,
            clip=need_clip,
            dist=is_distributed,
            fuse_param=fuse_param,
            comm_overlap=comm_overlap,
            act=act,
            comm_group=comm_group,
            dst=dst,
            acc_steps=acc_step,
            scale_after_comm=scale_after_comm,
        )
        other, other_buffers = obtain_storage(
            other_params,
            fuse_param=fuse_param,
            comm_overlap=comm_overlap,
            use_main_grad=use_main_grad,
            clip=need_clip,
            dist=is_distributed,
            act=act,
            comm_group=comm_group,
            dst=dst,
            acc_steps=acc_step,
            scale_after_comm=scale_after_comm,
        )
        decay_fused += decay
        all_fused += decay
        all_fused += other
        all_buffers += decay_buffers
        all_buffers += other_buffers

    return decay_fused, all_fused, all_buffers