#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The file has been adapted from fairscale file:
# https://github.com/facebookresearch/fairscale/blob/main/fairscale/optim/oss.py
# Git commit hash: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
# We retain the following license from the original files:

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import warnings

from collections import OrderedDict

import paddle
from paddle.fluid import core
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.distributed import fleet, ParallelMode

# Short module-level alias for fleet's hybrid-parallel gradient-clip wrapper.
# The optimizer below wraps a ClipGradByGlobalNorm with it when a hybrid
# communicate group exists and the parallel mode is not pure data parallel.
HybridParallelClipGrad = (
    fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer.HybridParallelClipGrad
)
from paddle.distributed.collective import (
    _get_global_group,
    broadcast,
    new_group,
)

from .group_sharded_storage import ParamStorage, GradStorage
from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad

# Byte alignment applied to every parameter slot inside a fused storage buffer:
# 256 bytes on CUDA devices, 4096 bytes on the CPU (used by the offload path).
alignment = {"gpu": 256, "cpu": 4096}

# Size in bytes of a single element for each supported parameter dtype.
align = {
    Type.fp16.value: 2,
    Type.bf16.value: 2,
    Type.fp32.value: 4,
}


class GroupShardedOptimizerStage2(Optimizer):
    """
    A wrapper for Sharding Stage2 Optimizer in Dygraph.

    Parameters are partitioned across the ranks of the sharding group (see
    ``_segment_params``). After the wrapped optimizer steps, each rank's
    partition is broadcast from its owner rank to every other rank.

    .. warning: ShardingOptimizer encapsulates the optimization strategy and integrates it into the optimizer.

    .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf

    """

    # TODO (Baibaifan)
    # Feature Notes:
    # 1. Unified memory for parameters and parameters.grad to InternalStorage.
    # 2. Support the segmentation of optimizer parameters and partial updating of parameters.
    # 3. Dynamically adjust training parameters and models.
    # 4. Support offload function.
    # 5. Support the establishment of independent communication groups.
    # 6. Broadcast_fp16 is not supported now.
    def __init__(
        self,
        params,
        optim,
        group=None,
        offload=False,
        device="gpu",
        pertrain_sync_models=True,
        dp_group=None,
        **kw
    ):
        """
        Args:
            params: Parameters to shard. Either a list of parameters or a list
                of parameter-group dicts, each holding a "params" entry.
            optim: The wrapped optimizer that performs the actual update. It
                must expose a ``_master_weights`` attribute.
            group: Sharding communication group. Defaults to a new group over
                all ranks of the global group.
            offload (bool): Keep fp32 master weights and their grads on the
                CPU. Requires at least one trainable fp16 parameter.
            device (str): Key into ``alignment`` ("gpu" or "cpu") used when
                sizing the fused parameter storages.
            pertrain_sync_models (bool): Broadcast all parameters from the
                first rank of ``group`` (and of ``dp_group``, if given) during
                construction.
            dp_group: Optional data-parallel group; only the stage2 + dp
                hybrid combination is supported.
            **kw: Accepted for compatibility; unused here.
        """

        super().__init__(learning_rate=optim._learning_rate, parameters=params)
        assert core.is_compiled_with_cuda(), "Only GPU is supported now"

        # Segmentation information
        # {dtype: [[params of rank 0], [params of rank 1], ...]}
        self._dtype_rank_params = OrderedDict()
        self._param2rank = {}  # {param.name: owner rank}
        self.__segment_params = []
        self._rank_buffer_size = {}  # {dtype: {rank: numel+alignment}}
        self._param2align = {}  # {param.name: align}

        # Default information
        self._optim = optim

        # sharing stage 2 comm overlap flag
        self._reduce_overlap = False
        # record the last task used for comm overlap for sharding stage 2
        self._comm_task = None

        assert hasattr(
            self._optim, "_master_weights"
        ), "Must use optimizer with _master_weights attribute"

        # Support parameter group and parameter list
        self._local_params = []
        if isinstance(params[0], dict):
            for param_group in params:
                self._local_params.extend(list(param_group["params"]))
        else:
            self._local_params.extend(list(params))

        self._default_device = device
        # True when at least one trainable parameter is stored in fp16.
        self._pfp16 = any(
            p.trainable and p.dtype == Type.fp16.value
            for p in self._local_params
        )

        self._broadcast_overlap = False
        self._forward_pre_hook_remove_helper = []
        try:
            # The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
            # Have to sort the params to make sure all params are in the forward using order.
            self._broadcast_order_params = sorted(
                self.local_params,
                key=lambda x: int(x.name.split('.')[0].split('_')[-1]),
            )
        except ValueError:
            # A name whose prefix does not end in "_<int>" cannot be ordered;
            # _set_broadcast_overlap falls back to the given parameter order.
            self._broadcast_order_params = None

        self._group = (
            new_group(_get_global_group().ranks) if group is None else group
        )

        # only support to combine stage2 and dp hybrid parallel now.
        self._dp_group = dp_group
        self.world_size = self._group.nranks
        self._rank = self._group.rank
        self._global_root_rank = self._group.ranks[0]

        # Synchronous all ranks models
        if pertrain_sync_models:
            self._sync_params_and_buffers()

        self.param_storages = {}  # {dtype: {rank: InternalStorage}}

        if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
            logging.warning(
                "While using ClipGradByGlobalNorm in GroupShardedOptimizerStage2, the grad clip of original optimizer will be changed."
            )

            hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
            if (
                hcg
                and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
            ):
                self._optim._grad_clip = HybridParallelClipGrad(
                    self._optim._grad_clip, hcg
                )
            else:
                self._optim._grad_clip = GroupShardedClipGrad(
                    self._optim._grad_clip, paddle.get_device(), self._group
                )
            # Parameter groups carry their own grad_clip; keep them in sync.
            if self._optim._parameter_list and isinstance(
                self._optim._parameter_list[0], dict
            ):
                for item in self._optim._param_groups:
                    if "grad_clip" in item:
                        item["grad_clip"] = self._optim._grad_clip

        if offload:
            assert (
                self._pfp16
            ), "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16"

        self.offload = offload  # Using for offload
        self.offload_device = "cpu"
        self.offload_buffer_size = 0
        self.offload_param2align = {}
        self.offload_params = None
        self.offload_grads = None
        self.dev_id = int(paddle.get_device().split(":")[1])

        self._master_params = {}

        # Update optimizer parameters and adjust parameter storage and use according to rank.
        self._update_opt_status()

    @paddle.autograd.no_grad()
    def _sync_params_and_buffers(self):
        """
        Sync all model states for all ranks
        """
        # Every local parameter is broadcast synchronously from the first rank
        # of the sharding group and, when a dp group is configured, from the
        # first rank of the dp group as well.
        for p in self._local_params:
            broadcast(
                p, src=self._global_root_rank, group=self._group, sync_op=True
            )

            if self._dp_group:
                broadcast(
                    p,
                    src=self._dp_group.ranks[0],
                    group=self._dp_group,
                    sync_op=True,
                )

    def _update_task(self, task):
        """Remember ``task`` as the most recent reduce communication task."""
        if self._reduce_overlap:
            assert task is not None
        # Only track of the last reduce task.
        # Since all tasks are on the same stream, only need to wait the last one.
        # After waiting for the last reduce task, all reduce tasks before have already finished.
        self._comm_task = task

    def _set_reduce_overlap(self, reduce_overlap):
        # Enable gradients' reduces overlap with backward calculation.
        self._reduce_overlap = reduce_overlap

    def _set_broadcast_overlap(
        self, broadcast_overlap, layers=None, num_groups=None
    ):
        """
        Enable or disable overlapping the post-step parameter broadcast with
        the forward computation of the next batch.

        Args:
            broadcast_overlap (bool): Whether to enable the overlap.
            layers: The model whose leaf sublayers receive forward pre-hooks
                that wait on the broadcasts. Required when enabling.
            num_groups (int, optional): Number of communication groups that
                the broadcasts are spread over round-robin. Defaults to 1
                (the sharding group itself). A value larger than the number of
                params to broadcast falls back to 1 with a warning.
        """
        # Enable post optimizer broadcasts overlap with the forward calculation of next batch.
        self._broadcast_overlap = broadcast_overlap
        if self._broadcast_overlap:
            assert (
                layers is not None
            ), "To enable broadcast overlap forward, please pass the module to the function."
            self._layers = layers
            warnings.warn(
                "Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
                "must be called manually before calling `paddle.save()` and before and inference."
            )
            if self._broadcast_order_params is None:
                # Params' names should be like column_linear_32.w_0 patter to get the best performance.
                warnings.warn(
                    r"The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
                    "overlap broadcast may harm the performance."
                )
                self._broadcast_order_params = self._local_params

        if num_groups is None:
            # Nothing requested explicitly: use the sharding group alone.
            num_groups = 1
        else:
            # The ordered list may still be None when overlap is disabled.
            broadcast_params = (
                self._broadcast_order_params
                if self._broadcast_order_params is not None
                else self._local_params
            )
            if num_groups > len(broadcast_params):
                warnings.warn(
                    "The num_groups for broadcast is larger than the number of params to be broadcast. "
                    "It will set to default value: 1 (use the default sharding group)."
                )
                num_groups = 1

        assert (
            isinstance(num_groups, int) and num_groups > 0
        ), "num_groups should be a positive integer"

        self._number_of_broadcast_groups = num_groups
        self._broadcast_groups = [
            None for _ in range(self._number_of_broadcast_groups)
        ]
        self._broadcast_groups[0] = self._group

        # Extra groups span the same ranks; they only provide extra channels.
        ranks = self._group.ranks
        for i in range(1, self._number_of_broadcast_groups):
            self._broadcast_groups[i] = new_group(ranks)

    def _generate_master_params(self, trainable_params):
        """
        Create fp32 master copies of ``trainable_params``.

        With offload, every param gets a CPU fp32 tensor in
        ``self._master_params`` (created once per name). Otherwise only fp16
        params get an fp32 copy, stored in the wrapped optimizer's
        ``_master_weights``.
        """
        if self.offload:
            for param in trainable_params:
                if param.name not in self._master_params:
                    self._master_params[param.name] = core.eager.Tensor(
                        name=param.name,
                        value=param.cast(dtype=Type.fp32.value).numpy(),
                        place=core.CPUPlace(),
                        stop_gradient=param.stop_gradient,
                    )
        else:
            for param in trainable_params:
                if param.dtype == Type.fp16.value:
                    master_tensor = paddle.cast(param, Type.fp32.value)
                    master_tensor.name = param.name
                    self._optim._master_weights[param.name] = master_tensor

    def _update_opt_status(self):
        """Update optimizer status and parameter storage information, and special functions to be developed."""
        # func 1
        self._integration_params()

    # Segment helpers

    def _segment_params(self):
        """
        Divide all optimizer parameters equally into rank.

        Greedy partition: each parameter, in the given order, goes to the rank
        with the smallest trainable numel so far. The result is cached.
        """
        if not self.__segment_params:
            param_lists = [[] for _ in range(self.world_size)]
            numels = [0] * self.world_size
            for param in self._local_params:
                # Add this param to rank with smallest size.
                rank = numels.index(min(numels))
                param_lists[rank].append(param)

                # Statistical real numels: frozen params add no load.
                if param.trainable:
                    numels[rank] += param._numel()

            self.__segment_params = param_lists
        return self.__segment_params

    @property
    def local_params(self):
        return self._local_params

    @property
    def param2rank(self):
        """Map the params to the rank which owns them"""
        if len(self._param2rank) == 0:
            for rank, params in enumerate(self._segment_params()):
                for param in params:
                    self._param2rank[param.name] = rank
        return self._param2rank

    @property
    def dtype_rank_params(self):
        """
        Divide the parameters into groups according to rank and dtype.

        Only trainable params are included. Each per-rank list is sorted by
        numel, ascending. The result is cached.
        """
        if len(self._dtype_rank_params) == 0:
            # Assign the parameters of each rank according to the type
            trainable_params = [p for p in self._local_params if p.trainable]
            for param in trainable_params:
                if param.dtype not in self._dtype_rank_params:
                    self._dtype_rank_params[param.dtype] = [
                        [] for _ in range(self.world_size)
                    ]
                self._dtype_rank_params[param.dtype][
                    self.param2rank[param.name]
                ].append(param)

            # Sort per rank params by size
            for per_rank_params in self._dtype_rank_params.values():
                for rank_params in per_rank_params:
                    rank_params.sort(key=lambda x: x._numel())

        return self._dtype_rank_params

    @staticmethod
    def _align_padding(numel, dtype, device):
        """
        Return how many ``dtype`` elements must follow a tensor of ``numel``
        elements so that its byte size is a multiple of ``alignment[device]``.
        """
        size = numel * align[dtype]
        remaining = size % alignment[device]
        padding_bytes = 0 if remaining == 0 else alignment[device] - remaining
        return padding_bytes // align[dtype]

    @property
    def rank_buffer_size(self):
        """
        Count the memory size of the parameters corresponding to rank under the corresponding dtype.

        Sizes are in elements, padding included. The padding of each param is
        also recorded in ``self._param2align``. The result is cached.
        """
        if len(self._rank_buffer_size) == 0:
            for dtype, per_rank_params in self.dtype_rank_params.items():
                rank_sizes = self._rank_buffer_size.setdefault(dtype, {})
                for dst_rank, params in enumerate(per_rank_params):
                    rank_sizes.setdefault(dst_rank, 0)
                    for param in params:
                        if not param.trainable:
                            continue
                        align_ = self._align_padding(
                            param._numel(), dtype, self._default_device
                        )
                        rank_sizes[dst_rank] += param._numel() + align_
                        self._param2align[param.name] = align_

        return self._rank_buffer_size

    def _integration_params(self):
        """
        Integrate the parameters into a continuous memory according to rank, and support the update of training parameters.
        """

        for dtype, per_rank_params in self.dtype_rank_params.items():
            dtype_storages = self.param_storages.setdefault(dtype, {})

            for dst_rank, params in enumerate(per_rank_params):
                if not params:
                    continue

                # Merge all the trainable params in a single InternalStorage
                trainable_params = [p for p in params if p.trainable]
                if self._pfp16 and dst_rank == self._rank:
                    self._generate_master_params(trainable_params)
                if trainable_params:
                    param_storage = ParamStorage(
                        size=self.rank_buffer_size[dtype][dst_rank],
                        dtype=dtype,
                        device=self._default_device,
                    )

                    param_storage.add_rank_params(
                        trainable_params, self._param2align
                    )
                    dtype_storages[dst_rank] = param_storage

        # Clear the InternalStorage keys which are not in use anymore
        dtype_in_use = list(self.dtype_rank_params.keys())
        dtype_to_pop = [d for d in self.param_storages if d not in dtype_in_use]
        for d in dtype_to_pop:
            self.param_storages.pop(d)

        if self.offload:
            # NOTE: this makes the optimizer's _master_weights the very same
            # dict as self._master_params. The offload buffer registered at
            # the end of this branch therefore also lands in
            # self._master_params under the buffer's name.
            self._optim._master_weights = self._master_params
            cpu_master_params = list(self._master_params.values())
            for param in cpu_master_params:
                align_ = self._align_padding(
                    param._numel(), Type.fp32.value, self.offload_device
                )
                self.offload_buffer_size += param._numel() + align_
                self.offload_param2align[param.name] = align_

            if cpu_master_params:
                with device_guard(self._rank, self.offload_device):
                    self.offload_params = ParamStorage(
                        size=self.offload_buffer_size,
                        dtype=Type.fp32.value,
                        device=self.offload_device,
                    )
                    self.offload_params.buffer.name = "offload_buffer"
                    self.offload_params.add_rank_params(
                        cpu_master_params, self.offload_param2align, False
                    )
                    self.offload_params.buffer.stop_gradient = False

                    self.offload_grads = GradStorage(
                        size=self.offload_buffer_size,
                        dtype=Type.fp32.value,
                        device=self.offload_device,
                        destination=self._rank,
                        parm2align=self.offload_param2align,
                        convert_cpu=True,
                    )
                    for p in cpu_master_params:
                        self.offload_grads.add_grad(
                            p, self.offload_param2align[p.name]
                        )

                    self._optim._master_weights[
                        self.offload_params.buffer.name
                    ] = self.offload_params.buffer

    def _offload_acc_grad(self, param_name, grad_fp32_cpu):
        """accumulate grads with offload strategy"""
        with device_guard(self._rank, self.offload_device):
            if param_name in self._master_params:
                master_param = self._master_params[param_name]
                if master_param.grad is None:
                    master_param._copy_gradient_from(grad_fp32_cpu)
                else:
                    master_param.grad.add_(grad_fp32_cpu)

        self.offload_params.buffer._copy_gradient_from(
            self.offload_grads.buffer
        )

    def _offload_scale_grad(self, scale_size):
        """scale grads with offload strategy"""
        with device_guard(self._rank, self.offload_device):
            self.offload_grads.buffer.scale_(scale=scale_size)

    def _offload_clear_grad(self):
        """clear grads with offload strategy"""
        with device_guard(self._rank, self.offload_device):
            self.offload_grads.buffer.zero_()

    def step(self):
        """
        A wrapper for Optimizer's step function to finish the update operation of the optimizer.
        """
        # This method won't be called directly by opt.step()!
        # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
        if self._broadcast_overlap:
            # Clear the pre forward hook in the optimizer step.
            for hook_remove in self._forward_pre_hook_remove_helper:
                hook_remove.remove()
            self._forward_pre_hook_remove_helper = []

        # Run the optimizer of the current rank step
        if self.offload:
            params_list = [self.offload_params.buffer]

            # TODO(Baibaifan): Offload will support param_groups later
            if not isinstance(self._optim._param_groups[0], dict):
                self._optim._parameter_list = params_list
                self._optim._param_groups = params_list

            with device_guard(device=self.offload_device):
                self._optim.step()

            # Copy the updated CPU master weights back into the device params.
            for param in self._local_params:
                if param.name in self._master_params:
                    param.set_value(
                        self._master_params[param.name]
                        .cuda(self.dev_id)
                        .cast(dtype=param.dtype)
                    )
        else:
            self._optim.step()

        # Synchronize all the updated shards in between the ranks
        self._broadcast_params()

    def minimize(self, *args, **kwargs):
        # Accept any arguments so that the usual `minimize(loss, ...)` call
        # hits this explicit error rather than a TypeError.
        raise RuntimeError(
            "optimizer.minimize() not support now, please use optimizer.step()"
        )

    def set_state_dict(self, state_dict):
        self._optim.set_state_dict(state_dict)

    def state_dict(self):
        return self._optim.state_dict()

    def _clear_cache(self):
        """Drop the cached segmentation and everything derived from it."""
        self.__segment_params.clear()
        self._dtype_rank_params.clear()
        self._param2rank.clear()
        # Buffer sizes depend on the segmentation above; without this they
        # would be served stale after the next re-segmentation.
        self._rank_buffer_size.clear()

    @paddle.autograd.no_grad()
    def _broadcast_params(self):
        """Broadcast the parameters of the current rank to each rank"""

        # Exchange all the shards with the other ranks
        if self._broadcast_overlap:
            self._broadcast_params_overlap_forward()
            return

        for dtype_per_rank in self.param_storages.values():
            for dst_rank, internal_storage in dtype_per_rank.items():
                broadcast(
                    tensor=internal_storage.buffer,
                    src=self._group.ranks[dst_rank],
                    group=self._group,
                    sync_op=True,
                )

    def _forward_pre_hook_function(self, tasks):
        # Since the layers will call pre hook by `forward_pre_hook(self, inputs)`,
        # the helper functions needs the x and y to take those params.
        def __impl__(x, y):
            for task in tasks:
                # Wait for broadcast task before using the result of the broadcast.
                task.wait()

        return __impl__

    @paddle.autograd.no_grad()
    def _broadcast_params_overlap_forward(self):
        # Exchange all the shards with the other ranks,
        # but overlap the broadcast with next batch's calculation.
        param2rank = self.param2rank
        param2task = {}
        group_idx = 0
        for x in self._broadcast_order_params:
            if not x.trainable:
                continue
            # Spread the async broadcasts over the groups round-robin.
            group = self._broadcast_groups[group_idx]
            group_idx = (group_idx + 1) % self._number_of_broadcast_groups
            task = broadcast(
                tensor=x,
                src=group.ranks[param2rank[x.name]],
                group=group,
                sync_op=False,
            )
            assert x.name not in param2task
            param2task[x.name] = task

        for layer in self._layers.sublayers():
            if layer.sublayers():
                continue
            # Register forward pre hook for leaf layers. This will get the best performance.
            tasks = [
                param2task[param.name]
                for param in layer.parameters()
                if param.trainable and param.name in param2task
            ]
            if not tasks:
                # Nothing to wait for; a hook would be a no-op.
                continue
            self._forward_pre_hook_remove_helper.append(
                layer.register_forward_pre_hook(
                    self._forward_pre_hook_function(tasks)
                )
            )