#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import copy
import time
import contextlib
import logging
import functools
import numpy as np
from itertools import chain
from functools import reduce
from types import MethodType
from collections import deque, OrderedDict

import paddle
from paddle import nn
from paddle.autograd import PyLayer
import paddle.fluid.core as core
import paddle.distributed as dist
from paddle.fluid.framework import ParamBase
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.distributed.collective import _get_global_group

from .sharding_utils import Type, ShardingClipGrad, device_guard
from ..pp_utils.utils import _all_gather
from ...utils.internal_storage import GradStorage

# CUDA alignment 256 bytes
alignment = {"gpu": 256, }
align = {
    Type.fp16.value: 2,
    Type.fp32.value: 4,
}

global CHECK_LAYER
CHECK_LAYER = dict()  # Help to check layer's id -> layer's name


class ShardingStage3(nn.Layer):
    """ 
    A wrapper for Sharding Stage3 Layer in Dygraph. 

    .. warning:: ShardingStage3 encapsulates the layer strategy and integrates it into the nn.Layer.

    .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
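
    Examples:
        A minimal usage sketch. The model, optimizer and data below are
        illustrative placeholders, the import path may differ across Paddle
        versions, and the script must be launched with
        ``paddle.distributed.launch`` on at least two GPUs.

        .. code-block:: python

            import paddle
            import paddle.distributed as dist
            from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

            dist.init_parallel_env()

            model = paddle.nn.Linear(1024, 1024)
            optimizer = paddle.optimizer.AdamW(
                learning_rate=0.001, parameters=model.parameters())
            model = ShardingStage3(model, optimizer=optimizer)

            data = paddle.randn([4, 1024], dtype="float32")
            loss = paddle.mean(model(data))
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()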
    """

    # TODO (Baibaifan)
    # Feature Notes:
    # 1. Supports segmenting each layer's parameters across the global ranks.
    # 2. Supports separate communication and computation flows.
    # 3. Supports the offload function.
    # 4. Supports the establishment of independent communication groups.

    def __init__(self,
                 layer,
                 optimizer,
                 group=None,
                 sync_buffers=False,
                 device="gpu",
                 segment_size=2**15,
                 pertrain_sync_models=True,
                 accumulate_grads=False,
                 offload=False,
                 sync_comm=False):
        super().__init__()

        # Default configs
        assert core.is_compiled_with_cuda(), "Only support CUDA."
        self._layer = layer
        self._default_device = device
        self.__sync_buffers = sync_buffers
        self._accumulate_grads = accumulate_grads
        self._offload = offload
        self._sync_comm = sync_comm
        # segmentation size
        self._segment_size = segment_size

        global DEV
        DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
        ).split(":")[0]
        global DEV_ID
        DEV_ID = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
                                                            .split(":")[1])
        global param2dtype
        param2dtype = dict()

        # Communication group establishment
        self._group = dist.new_group(_get_global_group()
                                     .ranks) if group is None else group
        self._world_size_scaling = 1.0 / self._group.nranks
        assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1."
        self._rank = self._group.rank
        self._global_root_rank = 0  # picking rank 0 as the reference
        self._global_ranks = self._group.ranks

        # Parameter segmentation for global ranks
        # After flatten -> self._param2buffer_size, self._param2buffer, self._trainable_params
        self._param2buffer_size = dict()  # {param.name: size}
        self._param2buffer = dict(
        )  # {param.name: [(start0, end0),(start1, end1), ...]}
        self._trainable_params = dict()  # {id(layer): [trainable_params]}
        self._unslice_params = set()  # param's numel <= segment_size
        self._unslice_params2align = dict()  # {param.name: param's align}
        self._grad_storages = dict()  # {param.dtype: GradStorage}

        assert not isinstance(
            optimizer, list), "Multiple optimizers are not supported now."
        self._optim = _OptimizerWrapper(optimizer, self._offload, self._group,
                                        self._update_params_slice)
        self._ori_parameter_list = self._optim._parameter_list
        self._ori_param_groups = self._optim._param_groups

        # Replace optimizer's _grad_clip
        if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
            logging.warning(
                "While using ClipGradByGlobalNorm in ShardingStage3, the grad clip of original optimizer will be changed."
            )
            self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
                                                      paddle.get_device(),
                                                      self._group)

        # Synchronize models across all ranks
        if pertrain_sync_models:
            self._sync_params_and_buffers()

        self._segment_rank_params(self._layer)

        # Add unslice params to master_weight in fp16
        self._handle_unslice_params()

        # In the first step, record the execution order of the layer
        self._order_tracer = OrderedDict()
        self._order_tracer["order"] = 0
        self._order_tracer["layer"] = list()

        # Register task flow
        self._task_flow = TaskFlow()

        # Register forward hooks
        self._register_forward_hooks(self._layer)

        # Register backward parameter hooks
        self._register_backward_hooks()

        # Redefine optimizer step and clear function
        self._redefine_opt_step()
        self._redefine_opt_clear()

    @paddle.no_grad()
    def _sync_params_and_buffers(self):
        """
        Sync all model states for all ranks
        """

        for p in self._layer.parameters():
            dist.broadcast(
                p,
                src=self._global_root_rank,
                group=self._group,
                use_calc_stream=True)

        # Multi stream operation will be supported later
        dist.wait(tensor=p, group=self._group, use_calc_stream=True)

    def _clear_gradients(self):
        assert len(self._trainable_params.keys()) > 0
        current_layer_params = self._layer.parameters(include_sublayers=True)
        # 1.Handle param's slice
        trainable_params = list(
            filter(lambda p: p.trainable and p not in self._unslice_params,
                   current_layer_params))
        for param in trainable_params:
            assert hasattr(
                param, "fw_storage"
            ), "Find {} don't have fw_storage attribute.".format(param.name)

            param.fw_storage.clear_gradient(False)
            param.fw_storage._gradient_set_empty(False)
            param.bw_storage._clear()
        # 2.Handle unslice param
        if not self._offload:
            for grad_storage in self._grad_storages.values():
                grad_storage.buffer.zero_()
        else:
            for param in list(self._unslice_params):
                param.clear_gradient(False)
                param._gradient_set_empty(False)
                tmp_var = param.cuda(DEV_ID)
                param._clear()
                if tmp_var.dtype == Type.fp32.value and param2dtype[
                        param.name] == Type.fp16.value:
                    tmp_var = paddle.cast(tmp_var, Type.fp16.value)
                tmp_var._share_buffer_to(param)
                tmp_var._clear()
            for grad_storage in self._grad_storages.values():
                grad_storage.manumal_relase()
                grad_storage.rebuild()

    # Update the param memory slices
    def _update_params_slice(self):
        update_list = self._update_params()

        if not isinstance(self._optim._param_groups[0], dict):
            slice_params = [param.fw_storage for param in update_list]
            self._optim._parameter_list = slice_params + list(
                self._unslice_params)
            self._optim._param_groups = slice_params + list(
                self._unslice_params)
        else:
            params_name_list = list(map(lambda p: p.name, update_list))
            fw_storage_name_list = list(
                map(lambda p: p.fw_storage.name, update_list))
            for param_group in self._optim._param_groups:
                p_group = []
                for p in param_group['params']:
                    if p.name in params_name_list:
                        p_group.append(p.fw_storage)
                    elif p.name in fw_storage_name_list:
                        p_group.append(update_list[fw_storage_name_list.index(
                            p.name)].fw_storage)
                    elif p in self._unslice_params:
                        p_group.append(p)
                param_group['params'] = p_group

    def forward(self, *inputs, **kwargs):
        """
        A wrapper for Sharding Stage3 layer.
        """
        # 1.Sync layer's buffers state
        if self.__sync_buffers:
            self._sync_buffers()

        # 2.Normal FW on the base model
        fw = self._layer(*inputs, **kwargs)

        return fw

    def _handle_unslice_params(self):
        buffer_size = dict()
        buffer_size[Type.fp32.value] = 0
        buffer_size[Type.fp16.value] = 0
        for param in self._unslice_params:
            # Update optimizer master weights
            if param.dtype == Type.fp16.value and not self._offload:
                self._optim._master_weights[param.name] = paddle.cast(
                    param, Type.fp32.value)
            param2dtype[param.name] = param.dtype
            p_align = self._param2align(param)
            self._unslice_params2align[param.name] = p_align
            buffer_size[param.dtype] += param._numel() + p_align

        # Create unslice_params'grad
        for param in sorted(list(self._unslice_params), key=lambda p: p.name):
            if param.dtype not in self._grad_storages.keys():
                self._grad_storages[param.dtype] = GradStorage(
                    buffer_size[param.dtype],
                    dtype=param.dtype,
                    device=self._default_device,
                    destination=self._rank,
                    parm2align=self._unslice_params2align)
            self._grad_storages[param.dtype].add_grad(
                param, self._unslice_params2align[param.name])

    def _segment_rank_params(self, layer, name="last_layer"):
        """
        Flatten parameters according to layer.
        """
        current_layer_params = _current_layer_params(layer)
        if current_layer_params:
            CHECK_LAYER[id(layer)] = name
            self._flatten_layer_params(layer, current_layer_params)

        for name, sub_layer in layer.named_children():
            self._segment_rank_params(sub_layer, name)

    def _flatten_layer_params(self, layer, current_layer_params):
        """
        Parameter segmentation and memory integration.
        """

        def _add_manage_info(trainable_param):
            return _PartitionParam(trainable_param)

        current_params = list()
        for p in current_layer_params:
            if p.trainable and p._numel() > self._segment_size:
                current_params.append(_add_manage_info(p))
            elif p.trainable:
                self._unslice_params.add(_UnsliceParam(p))

        assert id(layer) not in self._trainable_params.keys()
        self._trainable_params[id(layer)] = current_params

        for param in self._trainable_params[id(layer)]:
            if param.name in self._param2buffer.keys():
                continue
            self._param2buffer[param.name] = []
            # 1.Params alignment
            align_ = self._param2align(param)

            offset = align_ + param._numel()
            buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
                offset % self._group.nranks)
            self._param2buffer_size[param.name] = buffer_size
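            # e.g. with nranks = 4, a param of 1022 elements and align_ = 2:
            # offset = 1024, which is already a multiple of 4, so buffer_size
            # = 1024 and each rank owns a 256-element slice of the buffer.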

            # 2.Combination param buffer
            assert buffer_size % self._group.nranks == 0
            pre_buffer = buffer_size // self._group.nranks

            for rank_ in range(self._group.nranks):
                self._param2buffer[param.name].append(
                    (rank_ * pre_buffer, (rank_ + 1) * pre_buffer))

            # 3.Flatten layer params and release other rank buffer
            self._param_storage(param, buffer_size)
            # Record param's dtype
            param2dtype[param.name] = param.dtype

    def _param_storage(self, param, buffer_size):
        """
        This is a function to simplify the handling of parameter InternalStorages.
        """
        assert isinstance(buffer_size, int)
        value = np.zeros(
            buffer_size,
            dtype=np.float16) if Type.fp16.value == param.dtype else np.zeros(
                buffer_size, dtype=np.float32)
        buffer = core.VarBase(value=value, place=core.CPUPlace())

        param_shape = param.shape
        origin_state = param.stop_gradient
        param.stop_gradient = True
        param.flatten_()
        param.stop_gradient = origin_state
        start, end = self._param2buffer[param.name][self._rank]

        # Copy the current param value
        tmp_var = core.VarBase(
            tensor=buffer._slice(0, param._numel()), place=core.CPUPlace())
        param_cpu = param.cpu()
        tmp_var.value().get_tensor().set(param_cpu.value().get_tensor(),
                                         core.CPUPlace())
        param.value().get_tensor()._set_dims(param_shape)
        param._clear()

        # Current rank param_storage
        if self._offload:
            param.fw_storage = core.VarBase(
                buffer._slice(start, end),
                core.CPUPlace(), "slice@" + param.name)
        else:
            param.fw_storage = core.VarBase(
                buffer._slice(start, end), "slice@" + param.name)
        param.status = "part"

        # Update optimizer master weights
        if param.dtype == Type.fp16.value and not self._offload:
            self._optim._master_weights[param.fw_storage.name] = paddle.cast(
                param.fw_storage, Type.fp32.value)

    def _register_forward_hooks(self, layer):
        """
        Register pylayer to manage memory slices.
        There are four stages:
        FW
        1. Before the forward layers, synchronize the full parameters.
        2. After the forward layers, release the full parameter and keep the parameter slice.
        BW
        3. Before the backward layers, synchronize the full parameters and create param's grad.
        4. After the gradient accumulation, release the full parameter and keep the parameter slice.
        """
        current_layer_params = _current_layer_params(layer)
        if current_layer_params:
            self._register_forward_all_hooks(layer, self._task_flow)

        for _, sub_layer in layer.named_children():
            self._register_forward_hooks(sub_layer)

    def _register_forward_all_hooks(self, sub_layer, task_flow):
        def _forward_pre_hook(layer, inputs):
            return ForwardPreHooks(layer, self._order_tracer,
                                   self._trainable_params, self._param2buffer,
                                   self._rank, self._group, self._sync_comm,
                                   self._offload, task_flow)

        def _forward_post_hook(layer, inputs, outputs):
            return ForwardPostHooks.apply(
                outputs, layer, self._order_tracer, self._trainable_params,
                self._param2buffer, self._param2buffer_size, self._rank,
                self._group, self._sync_comm, self._offload, task_flow)

        # Register the forward pre-hook
        sub_layer.register_forward_pre_hook(_forward_pre_hook)

        # Register the forward post-hook
        sub_layer.register_forward_post_hook(_forward_post_hook)

    @paddle.no_grad()
    def _sync_buffers(self):
        """
        Sync all the param buffers from all ranks (e.g. batch norm statistics).
        """

        for buffer in self._layer.buffers(include_sublayers=True):
            dist.broadcast(
                buffer,
                self._global_root_rank,
                self._group,
                use_calc_stream=True)
        # Multi stream operation will be supported later
        dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)

    def __getattr__(self, name):
        """Forward missing attributes to wrapped layer."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self._layer, name)

    def _update_params(self):
        """
        Update parameters to optimizer memory slice.
        """
        update_list = []
        assert len(self._trainable_params.keys()) > 0
        current_layer_params = self._layer.parameters(include_sublayers=True)
        trainable_params = list(
            filter(lambda p: p.trainable and p not in self._unslice_params,
                   current_layer_params))
        # 1.Handle param's slice
        for param in trainable_params:
            assert hasattr(
                param,
                "fw_storage"), "Find {} don't have fw_storage attribute".format(
                    param.name)

            if self._accumulate_grads:
                if self._offload:
                    with device_guard(device="cpu"):
                        param.bw_storage.scale_(scale=self._world_size_scaling)
                else:
                    param.bw_storage.scale_(scale=self._world_size_scaling)
            param.fw_storage = _VarBaseWrapper(param)
            assert param.fw_storage.grad is None
            param.fw_storage._copy_gradient_from(param.bw_storage)
            update_list.append(param)

        # 2.Handle unslice param
        for grad_storage in self._grad_storages.values():
            grad_storage.buffer.scale_(scale=self._world_size_scaling)
            dist.all_reduce(
                tensor=grad_storage.buffer,
                group=self._group,
                use_calc_stream=True)
            dist.wait(
                tensor=grad_storage.buffer,
                group=self._group,
                use_calc_stream=True)

        if self._offload:
            for param in list(self._unslice_params):
                tmp_var = _device2cpu(param, convert_dtype=True)
                tmp_var._share_buffer_to(param)
                tmp_var._clear()

            for grad_storage in self._grad_storages.values():
                for p in grad_storage._params:
                    tmp_g = _device2cpu(p.grad, convert_dtype=True)
                    p.clear_gradient(False)
                    p._gradient_set_empty(False)
                    p._copy_gradient_from(tmp_g)
                    tmp_g._clear()
                grad_storage.buffer._clear()

        return update_list

    def get_all_parameters(self, convert2cpu=False):
        """
        Get the full parameters and return the corresponding task flows.
        """
        assert len(self._trainable_params.keys()) > 0
        current_layer_params = self._layer.parameters(include_sublayers=True)
        trainable_params = list(
            filter(lambda p: p.trainable and p not in self._unslice_params,
                   current_layer_params))
        t_flow = _allgather_buffer(
            trainable_params,
            self._group,
            use_calc_stream=True,
            task_flow=TaskFlow(),
            sync_wait=True,
            offload=self._offload,
            convert2cpu=convert2cpu)
        if convert2cpu:
            for param in trainable_params:
                t_flow.full_param[param.name]._share_buffer_to(param)

        self._optim._parameter_list = self._ori_parameter_list
        self._optim._param_groups = self._ori_param_groups

    def _register_backward_hooks(self):
        current_layer_params = self._layer.parameters(include_sublayers=True)
        trainable_params = list(
            filter(lambda p: p.trainable and p not in self._unslice_params,
                   current_layer_params))

        for param in trainable_params:
            allreduce_function = self._get_allreduce_fn(param)
            param._register_backward_hook(allreduce_function)

    def _get_allreduce_fn(self, param):
        @paddle.no_grad()
        def reduce(*_):
            if param.name in self._task_flow.full_grad.keys():
                full_grad = self._task_flow.full_grad[param.name]
                if not self._accumulate_grads:
                    full_grad.scale_(scale=self._world_size_scaling)
                # Only support sync allreduce current rank's layer now
                dist.all_reduce(
                    tensor=full_grad, group=self._group, use_calc_stream=True)
                dist.wait(
                    tensor=full_grad, group=self._group, use_calc_stream=True)

                start, end = self._param2buffer[param.name][self._rank]
                if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
                ).get_tensor()._is_initialized():
                    param.bw_storage = core.VarBase(
                        full_grad._slice(start, end)).detach().clone()
                    if self._offload:
                        param.bw_storage = _device2cpu(param.bw_storage, True)
                else:
                    if self._offload:
                        cpu_grad = _device2cpu(
                            core.VarBase(full_grad._slice(start, end))
                            .detach().clone(), True)
                        param.bw_storage = paddle.add(param.bw_storage,
                                                      cpu_grad)
                    else:
                        # param.bw_storage.add_(
                        #     core.VarBase(full_grad._slice(start, end))
                        #     .detach().clone())
                        param.bw_storage = paddle.add(
                            param.bw_storage,
                            core.VarBase(full_grad._slice(start, end)).detach(
                            ).clone())
                param.clear_gradient(False)
                param._gradient_set_empty(False)
                tmp_var = self._task_flow.full_grad.pop(param.name)
                tmp_var._clear()

            if param.name in self._task_flow.full_param.keys():
                if param.status == "all":
                    param.use_count = 0
                    param._clear()
                    start, end = self._param2buffer[param.name][self._rank]
                    param.fw_storage = core.VarBase(
                        self._task_flow.full_param[param.name]._slice(
                            start, end), param.name + "@slice").detach().clone()
                    param.status = "part"
                    tmp_var = self._task_flow.full_param.pop(param.name)
                    tmp_var._clear()

                    if self._offload:
                        param.fw_storage = _device2cpu(param.fw_storage, True)

        return reduce

    def _param2align(self, param):
        # CUDA alignment 256 bytes
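        # e.g. an fp16 param with 1000 elements occupies 2000 bytes;
        # 2000 % 256 = 208, so 48 padding bytes (24 fp16 elements) are added.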
        size = param._numel() * align[param.dtype]
        remaining = size % alignment[self._default_device]
        ali = 0 if remaining == 0 else alignment[
            self._default_device] - remaining
        align_ = ali // align[param.dtype]
        return align_

    def _redefine_opt_step(self):
        params_slice_func = self._update_params_slice
        opt_step = self._optim.step

        def _opt_step(self):
            if not self.update_scaler:
                params_slice_func()
            if self.offload:
                with device_guard(device="cpu"):
                    opt_step()
            else:
                opt_step()

        def _opt_minimize(self):
            raise RuntimeError(
                "optimizer.minimize() not support now, please use optimizer.step()"
            )

        self._optim.step = MethodType(_opt_step, self._optim)
        self._optim.minimize = MethodType(_opt_minimize, self._optim)

    def _redefine_opt_clear(self):
        clear_func = self._clear_gradients

        def _opt_clear(self):
            clear_func()

        self._optim.clear_grad = MethodType(_opt_clear, self._optim)


def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
                    group, sync_comm, offload, task_flow):

    # Record layer's id
    layer_id = id(layer)
    use_calc, sync_wait = False, False

    if layer_id not in order_tracer.keys() or sync_comm:
        use_calc, sync_wait = True, True

        # Whether to use calc stream
        task_flow.use_calc[layer_id] = use_calc
    else:
        # Whether to use calc stream
        task_flow.use_calc[layer_id] = use_calc
        # wait current layer params
        _wait_layer(trainable_params[layer_id], task_flow, group, use_calc,
                    offload)

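        # After the first step the execution order is known, so prefetch the
        # parameters of the next layer in that order (unless this is the last layer).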
        if layer_id == order_tracer["layer"][-1]: return
        order_ = order_tracer[layer_id]
        layer_id = order_tracer["layer"][order_ + 1]

    _allgather_buffer(
        trainable_params[layer_id],
        group,
        use_calc_stream=use_calc,
        task_flow=task_flow,
        sync_wait=sync_wait,
        offload=offload)

    return


class ForwardPostHooks(PyLayer):
    @staticmethod
    def forward(ctx, inputs, layer, order_tracer, trainable_params,
                param2buffer, param2buffer_size, rank, group, sync_comm,
                offload, task_flow):

        layer_id = id(layer)
        # release current layer full params
        _release_param(trainable_params[layer_id], param2buffer, rank,
                       task_flow, offload)

        if layer_id not in order_tracer.keys():
            order_ = order_tracer["order"]
            order_tracer[layer_id] = order_
            order_tracer["order"] += 1
            order_tracer["layer"].append(layer_id)

        # Record backward info
        ctx.order_tracer = order_tracer
        ctx.task_flow = task_flow
        ctx.group = group
        ctx.layer = layer
        ctx.sync_comm = sync_comm
        ctx.trainable_params = trainable_params
        ctx.param2buffer_size = param2buffer_size
        ctx.offload = offload

        return inputs

    @staticmethod
    def backward(ctx, *args):
        # Load context value
        order_tracer = ctx.order_tracer
        task_flow = ctx.task_flow
        group = ctx.group
        layer = ctx.layer
        trainable_params = ctx.trainable_params
        param2buffer_size = ctx.param2buffer_size
        sync_comm = ctx.sync_comm
        offload = ctx.offload
        layer_id = id(layer)
        use_calc, sync_wait = False, False

        # Allgather params synchronization
        if sync_comm:
            use_calc, sync_wait = True, True
            _allgather_buffer(
                trainable_params[layer_id],
                group,
                use_calc_stream=use_calc,
                task_flow=task_flow,
                sync_wait=sync_wait,
                offload=offload)
        else:
            _wait_layer(trainable_params[layer_id], task_flow, group, use_calc,
                        offload)

        # Create params's grad
        _create_params_grad(trainable_params[layer_id], param2buffer_size,
                            task_flow)

        # Whether to use calc stream
        task_flow.use_calc[layer_id] = use_calc
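        # Prefetch the parameters of the previous layer in forward order,
        # which is the next layer to be executed in the backward pass.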
        if layer_id != order_tracer["layer"][0] and not sync_comm:
            layer_next_id = order_tracer["layer"][order_tracer[layer_id] - 1]
            _allgather_buffer(
                trainable_params[layer_next_id],
                group,
                use_calc_stream=use_calc,
                task_flow=task_flow,
                sync_wait=sync_wait,
                offload=offload)

        return args


class TaskFlow:
    """
    Task flow, a one-way linked list used for task acquisition.
    """

    def __init__(self,
                 full_param=None,
                 full_grad=None,
                 use_calc=None,
                 callback=None):
        # Avoid mutable default arguments so that each TaskFlow owns its own state
        self.full_param = full_param if full_param is not None else dict()
        self.full_grad = full_grad if full_grad is not None else dict()
        self.use_calc = use_calc if use_calc is not None else dict()
        self.callback = callback


def _release_param(trainable_params,
                   param2buffer,
                   rank,
                   task_flow,
                   offload=False):
    for param in trainable_params:
        # Shared weights used in async communication are only cleared once their use count reaches zero
        param.use_count -= 1
        if param.use_count == 0:
            param._clear()
            if param.name in task_flow.full_param.keys():
                start, end = param2buffer[param.name][rank]
                with paddle.amp.auto_cast(enable=False):
                    param.fw_storage = core.VarBase(
                        task_flow.full_param[param.name]._slice(start, end),
                        param.name + "@slice").detach().clone()
                param.status = "part"
                tmp_var = task_flow.full_param.pop(param.name)
                tmp_var._clear()

                if offload:
                    param.fw_storage = _device2cpu(param.fw_storage)
    return


def _wait_layer(trainable_params,
                task_flow,
                group,
                use_calc_stream,
                offload=False):
    paddle.device.cuda.synchronize()
    for param in trainable_params:
        if param.status == "all":
            param.use_count += 1
            continue
        if param.name in task_flow.full_param.keys():
            full_param = task_flow.full_param[param.name]
            core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
                param)
            param.fw_storage._clear()
            param.fw_storage = None
            param.status = "all"
            param.use_count += 1
        else:
            _allgather_buffer(
                trainable_params,
                group,
                use_calc_stream=True,
                task_flow=task_flow,
                sync_wait=True,
                offload=offload)
            break
    return task_flow


def _allgather_buffer(trainable_params,
                      group,
                      use_calc_stream,
                      task_flow,
                      sync_wait=False,
                      offload=False,
                      convert2cpu=False):

    for param in trainable_params:
        if param.status == "all":
            param.use_count += 1
            continue

        if offload:
            param.fw_storage = _cpu2device(param)

        with paddle.amp.auto_cast(enable=False):
            full_param = _all_gather(
                param.fw_storage, group, use_calc_stream=use_calc_stream)

        # Allgather current layer in the 1st step synchronously
        if sync_wait:
            with paddle.amp.auto_cast(enable=False):
                dist.wait(
                    tensor=full_param,
                    group=group,
                    use_calc_stream=use_calc_stream)
            core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
                param)
            param.fw_storage._clear()
            param.fw_storage = None
            param.status = "all"
            param.use_count += 1
        task_flow.full_param[param.name] = full_param

        # Convert the parameter to CPU
        if convert2cpu:
            p_name = param.name
            param = _device2cpu(param)
            tmp_var = task_flow.full_param.pop(p_name)
            tmp_var._clear()
            task_flow.full_param[p_name] = param

    return task_flow


@paddle.no_grad()
def _create_params_grad(trainable_params, param2buffer_size, task_flow):
    for param in trainable_params:
        if param.name in task_flow.full_grad.keys():
            continue
        assert isinstance(param2buffer_size[param.name], int)
        temp_grad = paddle.zeros(
            [param2buffer_size[param.name]], dtype=param.dtype)
        param._copy_gradient_from(
            core.VarBase(temp_grad._slice(0, param._numel())))
        task_flow.full_grad[param.name] = temp_grad
    return task_flow


def _PartitionParam(param):
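    # Attach sharding bookkeeping attributes to the param:
    #   fw_storage - this rank's slice of the flattened parameter
    #   bw_storage - this rank's slice of the accumulated gradient
    #   status     - "part" (only the slice is held) or "all" (the full param is held)
    #   use_count  - reference count of the gathered full parameter (handles shared weights)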
    if not hasattr(param, "fw_storage"):
        setattr(param, "fw_storage", None)
        setattr(param, "bw_storage", None)
        setattr(param, "status", "all")
        setattr(param, "use_count", 0)
    return param


def _UnsliceParam(param):
    if not hasattr(param, "unslice"):
        setattr(param, "unslice", True)
    return param


def _VarBaseWrapper(param):
    varbase = param.fw_storage
    tmp_param = ParamBase(
        shape=varbase.shape, dtype=varbase.dtype, name="slice@" + param.name)
    varbase._share_buffer_to(tmp_param)
    tmp_param.regularizer = param.regularizer
    tmp_param.optimize_attr['learning_rate'] = param.optimize_attr[
        'learning_rate']
    varbase._clear()
    return tmp_param


def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
    if not hasattr(optimizer, "_optim"):
        setattr(optimizer, "_optim", optimizer)
        setattr(optimizer, "offload", offload)
        setattr(optimizer, "group", group)
        setattr(optimizer, "update_scaler", None)
        setattr(optimizer, "update_slice", update_params_slice)
    return optimizer


def _device2cpu(trans_param, convert_dtype=False):
    if convert_dtype:
        trans_param = paddle.cast(trans_param, Type.fp32.value)
    tmp_p = trans_param.cpu()
    trans_param._clear()
    return tmp_p


def _cpu2device(param):
    tmp_p = param.fw_storage.cuda(DEV_ID)
    param.fw_storage._clear()
    if tmp_p.dtype == Type.fp32.value and param2dtype[
            param.name] == Type.fp16.value:
        tmp_p = paddle.cast(tmp_p, Type.fp16.value)
    return tmp_p


def _current_layer_params(layer):
    if hasattr(layer, "extra_parameters"):
        return layer.parameters(include_sublayers=False) + list(
            layer.extra_parameters)
    return layer.parameters(include_sublayers=False)