parallel.py 33.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
Y
Yan Xu 已提交
16
import numpy as np
17
import warnings
18
from collections import OrderedDict
S
ShenLiang 已提交
19 20
import itertools
import warnings
21
from contextlib import contextmanager
22

S
ShenLiang 已提交
23
import paddle
24
from paddle import _C_ops, _legacy_C_ops
25 26 27 28 29 30
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
31
from ..layers import collective
32
from paddle.fluid.dygraph import base as imperative_base
33 34 35 36 37 38
from paddle.fluid.framework import (
    ParamBase,
    _in_legacy_dygraph,
    _non_static_mode,
    in_dygraph_mode,
)
39

40
# Public names exported by ``from ... import *``. Other module-level names
# (e.g. ``Env``, ``sync_params_buffers``, ``scale_loss``) remain importable
# explicitly but are not listed here.
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
41 42 43 44

# Module-level alias for ``core.ParallelStrategy``; this module instantiates it
# to build a default strategy from the environment when none is supplied.
ParallelStrategy = core.ParallelStrategy


45
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the global parallel context used by dygraph data parallelism.

    Deprecated since 2.0.0; use ``paddle.distributed.init_parallel_env``
    instead.

    Args:
        strategy (ParallelStrategy, optional): describes the parallel job
            (``nranks``, ``local_rank``, ``trainer_endpoints`` and
            ``current_endpoint``). If None, it is built from the environment
            variables read by ``ParallelEnv``. Default: None.

    Returns:
        ParallelStrategy|None: None when ``strategy.nranks < 2`` (nothing to
        initialize); otherwise the strategy that was used.

    Raises:
        AssertionError: if not running in dygraph mode, or if no expected
            place has been set.
        NotImplementedError: if a parallel context still has to be created
            and the current place is not a CUDAPlace, XPUPlace or NPUPlace.
    '''
    if strategy is None:
        strategy = _build_default_parallel_strategy()
    if strategy.nranks < 2:
        return

    assert (
        framework._non_static_mode() is True
    ), "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert (
        place is not None
    ), "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."

    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_ctx = core.NCCLParallelContext(strategy, place)
        elif isinstance(place, core.XPUPlace):
            parallel_ctx = core.BKCLParallelContext(strategy, place)
        elif isinstance(place, core.NPUPlace):
            parallel_ctx = core.HCCLParallelContext(strategy, place)
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # Raise explicitly: an ``assert "<message>"`` tests a non-empty
            # string, which is always true, so unsupported places would
            # silently fall through to ``_init_parallel_ctx()`` with no
            # context set.
            raise NotImplementedError(
                "Only support CUDAPlace or XPUPlace or NPUPlace for now."
            )
        parallel_helper._set_parallel_ctx(parallel_ctx)
        parallel_helper._init_parallel_ctx()
    return strategy
83 84


85
class ParallelEnv:
    """
    .. note::
        This API is not recommended, if you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    All values are read from the environment once, when the object is
    constructed; later changes to the environment are not reflected.

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", ""))

        def _first_selected_device(flag_name):
            # A FLAGS_selected_* variable may hold several comma-separated
            # ids; imperative mode only supports one device, so use the first.
            return int(os.getenv(flag_name, "0").split(",")[0])

        # Default to device 0 (the documented default) so that ``device_id``
        # is always defined, even when none of the backends below match
        # (e.g. a CPU-only build without PADDLE_XCCL_BACKEND). Previously the
        # attribute was left unset in that case and ``device_id`` raised
        # AttributeError.
        self._device_id = 0
        if self._device_type != "":
            self._device_id = _first_selected_device(
                'FLAGS_selected_{}s'.format(self._device_type)
            )
        elif core.is_compiled_with_cuda():
            self._device_id = _first_selected_device("FLAGS_selected_gpus")
        elif core.is_compiled_with_xpu():
            self._device_id = _first_selected_device("FLAGS_selected_xpus")
        elif core.is_compiled_with_npu():
            self._device_id = _first_selected_device("FLAGS_selected_npus")
        elif core.is_compiled_with_mlu():
            self._device_id = _first_selected_device("FLAGS_selected_mlus")

        self._trainer_endpoints = os.getenv(
            "PADDLE_TRAINER_ENDPOINTS", ""
        ).split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert (
            self._nrings > 0
        ), "nccl_nrings must be an integer greater than 0."
        assert (
            self._nrings < 9
        ), "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the selected device card for parallel training.

        If ``PADDLE_XCCL_BACKEND`` is set to ``<type>``, its value is the first
        id in ``FLAGS_selected_<type>s``. Otherwise it is the first id in
        ``FLAGS_selected_gpus``, ``FLAGS_selected_xpus``,
        ``FLAGS_selected_npus`` or ``FLAGS_selected_mlus``, depending on which
        backend Paddle was compiled with. The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id are %d" % env.device_id)
            # The device id are 1
        """
        return self._device_id

    @property
    def device_type(self):
        """
        The type of custom device for parallel training.

        Its value is equal to the value of the environment variable ``PADDLE_XCCL_BACKEND`` . The default value is "" (empty string).

        """
        return self._device_type

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint are %s" % env.current_endpoint)
            # The current endpoint are 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS`` split on commas.
        When the variable is unset, the value is ``['']`` (a list holding one empty string).

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

297

298 299 300 301 302 303
# NOTE: [ Compatible ] This class was originally named `Env`. That name is
# imprecise and may confuse users, so the class was renamed to `ParallelEnv`.
# `Env` is kept as an alias so that older examples and user code that still
# refer to `Env` keep working.
Env = ParallelEnv


304 305 306 307 308 309 310 311 312 313 314
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from the current environment.

    Returns:
        ParallelStrategy: with ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` copied from a single
        ``ParallelEnv`` snapshot.
    """
    strategy = ParallelStrategy()
    # Read the environment once instead of constructing ParallelEnv (and
    # re-parsing every environment variable) separately for each field.
    env = ParallelEnv()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """Flatten and concatenate each group of tensors into one 1-D tensor.

    Args:
        var_groups (dict): maps a group id to the list of tensors in that
            group (as built by ``build_groups``); only the values are used,
            in the dict's iteration order.

    Returns:
        list: one ``[coalesced_tensor, tensors, shapes]`` entry per group,
        where ``shapes`` records each tensor's original shape so that the
        coalesced tensor can later be split back into its parts.
    """
    from ..layers import nn

    coalesced_grads_and_grad_vars = []
    for grad_vars in var_groups.values():
        g_var_shapes = [g_var.shape for g_var in grad_vars]
        # Flatten every tensor to 1-D so the group can be concatenated.
        flattened_vars = [
            nn.reshape(x=g_var, shape=[np.prod(g_var.shape)])
            for g_var in grad_vars
        ]
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes]
        )
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape ``x`` in place to ``shape`` by tracing a ``reshape2`` op.

    ``x`` is passed as both the op's input and its ``Out`` output, so nothing
    is returned; the op's auxiliary ``XShape`` output is written to a scratch
    variable created with ``x``'s dtype.
    """
    scratch_xshape = framework._varbase_creator(dtype=x.dtype)
    op_inputs = {'X': x}
    op_outputs = {'Out': x, 'XShape': scratch_xshape}
    tracer = framework._dygraph_tracer()
    tracer.trace_op(
        type="reshape2",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs={'shape': shape},
    )
341 342 343 344


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Split each coalesced 1-D tensor back into its original tensors.

    Args:
        coalesced_grads_and_grad_vars (list): entries of the form
            ``[coalesced_tensor, target_tensors, target_shapes]``, as
            produced by ``_coalesce_tensors``.

    Each ``target_tensors[i]`` receives its slice of the coalesced tensor and
    is then reshaped to ``target_shapes[i]``.

    In legacy dygraph, ``split`` is traced through the tracer and
    ``_reshape_inplace`` restores the shapes. In eager mode, the split goes
    through ``_legacy_C_ops.split`` and each target is restored with
    ``reshape_``. In any other mode nothing happens.
    """
    use_legacy = _in_legacy_dygraph()
    use_eager = not use_legacy and in_dygraph_mode()
    if not (use_legacy or use_eager):
        return

    for coalesced, targets, shapes in coalesced_grads_and_grad_vars:
        sections = [np.prod(shape) for shape in shapes]
        if use_legacy:
            framework._dygraph_tracer().trace_op(
                type='split',
                inputs={'X': coalesced},
                outputs={'Out': targets},
                attrs={'sections': sections, 'axis': 0},
            )
        else:
            _legacy_C_ops.split(
                coalesced, targets, 'sections', sections, 'axis', 0
            )
        for target, shape in zip(targets, shapes):
            if use_legacy:
                _reshape_inplace(x=target, shape=shape)
            else:
                target.reshape_(shape=shape)
            assert target.shape == shape
375 376 377


def scale_loss(loss):
    """Divide ``loss`` by the data-parallel world size.

    Args:
        loss: the loss tensor to scale.

    Returns:
        ``loss`` itself when the world size is 1 or less. Otherwise
        ``loss / world_size``, where the divisor is a float32 tensor with
        ``stop_gradient=True``.
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    # Read the world size once instead of constructing ParallelEnv twice.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    loss_scale = to_variable(np.array([world_size]).astype("float32"))
    loss_scale.stop_gradient = True
    return loss / loss_scale


390 391
@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Bucket tensors into coalesced groups of roughly ``group_size`` bytes.

    Tensors are visited in order. A tensor joins the current group while the
    group's running byte count is still below ``group_size`` and the tensor's
    dtype matches the group's dtype. Otherwise it starts a new group. Each
    group is then flattened and concatenated by ``_coalesce_tensors``.

    Args:
        vars (list[Tensor]): tensors to group, in order. (The name shadows a
            builtin but is kept for keyword-argument compatibility.)
        group_size (int): soft byte limit per group. A group can exceed it by
            the size of the last tensor added to it.

    Returns:
        list: one ``[coalesced_tensor, tensors, shapes]`` entry per group, or
        an empty list when ``vars`` is empty.
    """
    # Empty input: nothing to group (previously ``vars[0]`` raised IndexError).
    if not vars:
        return []

    group_idx = 0
    group_bytes = 0
    group_dtype = vars[0].dtype
    var_groups = OrderedDict()

    for var in vars:
        nbytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if group_bytes < group_size and group_dtype == var.dtype:
            group_bytes += nbytes
        else:
            # Start a new group on size overflow or dtype change.
            group_bytes = nbytes
            group_dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(
    model, comm_group=None, src_rank=0, is_model_parallel=False
):
    """Broadcast a model's parameters and buffers from ``src_rank``.

    The eligible tensors are detached, coalesced into groups of up to about
    128 MiB (see ``build_groups``), and each group is broadcast synchronously
    over ``comm_group``. Every received group is then split back into the
    detached tensors it was built from.

    The following are skipped:
      * parameters with ``is_distributed`` set, when ``is_model_parallel`` is
        True;
      * parameters with a truthy ``no_sync`` attribute;
      * variables whose type is ``VOCAB``.

    Args:
        model: the layer whose ``_obtain_parameters_buffers()`` is synced.
        comm_group (optional): communication group for the broadcast.
        src_rank (int, optional): rank to broadcast from. Default: 0.
        is_model_parallel (bool, optional): skip distributed parameters.

    Raises:
        TypeError: if an entry is neither ``core.VarBase`` nor
            ``core.eager.Tensor``.
    """
    to_sync = []
    for _, tensor in model._obtain_parameters_buffers().items():
        if not isinstance(tensor, (core.VarBase, core.eager.Tensor)):
            raise TypeError(
                "The data type of '%s' must be Varbase or eager.Tensor"
                % tensor.name
            )

        if isinstance(tensor, (ParamBase, core.eager.Tensor)):
            # is_distributed params need no sync in model-parallel (mp) mode.
            skip_distributed = is_model_parallel and tensor.is_distributed
            # NOTE(shenliang03): some parameters opt out of synchronization,
            # such as moe's expert parameters.
            if skip_distributed or getattr(tensor, "no_sync", False):
                continue

        if tensor.type == core.VarDesc.VarType.VOCAB:
            continue

        to_sync.append(tensor.detach())

    if not to_sync:
        return

    # group size is 128M
    groups = build_groups(to_sync, 128 * 1024 * 1024)

    # Broadcast every group first, then split them all back.
    for coalesced, _, _ in groups:
        paddle.distributed.broadcast(
            coalesced, src=src_rank, group=comm_group, sync_op=True
        )

    for coalesced, targets, shapes in groups:
        sections = [np.prod(shape) for shape in shapes]
        paddle.fluid.framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced},
            outputs={'Out': targets},
            attrs={'sections': sections, 'axis': 0},
        )
455 456


457
class DataParallel(layers.Layer):
C
chengduo 已提交
458
    """
459
    Run the dygraph module with data parallelism.
C
chengduo 已提交
460

461
    Currently, DataParallel class only supports to run the dynamic graph
462 463
    with multi-process.

464 465 466 467 468
    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)
469

470
    2. start by ``paddle.distributed.launch`` module, for example:
471

472
        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
473 474

    And the content of `demo.py` is the code of examples.
C
chengduo 已提交
475

476 477
    Args:
        layers(Layer): The module that should be executed by data parallel.
478
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
479
            contains environment configuration related to parallel execution. Default: None.
480 481
        comm_buffer_size(int, optional):  It limits the memory size(MB) of one buffer
                                          parameters' gradient which is the input of communication
482
                                          calling(e.g NCCLAllReduce). Default: 25.
483
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
484
                                         calling. Making the last communication buffer size small is useful to
485
                                         improve performance. Default: 1.
486
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the
487 488 489 490 491 492
                                                all tensors in the return value of the wrapped model's
                                                forward function. For parameters not involved in loss
                                                calculation, their gradients will be marked as ready in
                                                advance to prepare reduce. Please note that all forward
                                                outputs derived from the wrapped model parameters must
                                                participate in the calculation of loss and subsequent
493
                                                gradient calculations. If not, serious error will occur.
494
                                                Note that setting the find_unused_parameters to True
495
                                                will affect computing performance. Therefore, if all parameters
496
                                                are sure to participate in the loss calculation and the
497
                                                autograd graph construction, please set it False. Default: False.
498

499 500 501
    Returns:
        Layer: The data paralleled module.

C
chengduo 已提交
502
    Examples:
503

C
chengduo 已提交
504
        .. code-block:: python
505 506
            :name: dp-example

507
            # required: distributed
508 509 510 511 512 513 514
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
515
                    super().__init__()
516 517
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
518

519 520 521 522
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
523
                # 1. initialize parallel environment
524 525
                dist.init_parallel_env()

526
                # 2. create data parallel layer & optimizer
527 528 529 530 531 532 533
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

534
                # 3. run layer
535 536 537 538
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
539

540 541 542 543 544 545 546 547 548 549
                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
550 551 552


    .. note::
553 554 555
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync',
        and manually implement 'all_reduce' before model optimization. There is an example
556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
        showing specific implemetation processing.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            # required: distributed
            import numpy
            import paddle
            import paddle.distributed as dist
            from paddle.autograd import PyLayer
            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
585
                    super().__init__()
586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615
                    self.linear = paddle.nn.Linear(2, 2)

                def forward(self, inputs):
                    inputs = cus_tanh.apply(inputs)
                    return self.linear(inputs)

            if __name__ == '__main__':
                dist.init_parallel_env()

                model = SimpleNet()
                model = paddle.DataParallel(model)
                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

                for step in range(10):
                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
                    x = paddle.to_tensor(x_data)
                    x.stop_gradient = False

                    # step 1 : skip gradient synchronization by 'no_sync'
                    with model.no_sync():
                        y_pred = model(x)
                        loss = y_pred.mean()
                        loss.backward()

                    # step 2 : fuse + allreduce manually before optimization
                    fused_allreduce_gradients(list(model.parameters()), None)

                    opt.step()
                    opt.clear_grad()

C
chengduo 已提交
616 617
    """

618 619 620 621 622 623 624 625 626
    def __init__(
        self,
        layers,
        strategy=None,
        comm_buffer_size=25,
        last_comm_buffer_size=1,
        find_unused_parameters=False,
        group=None,
    ):
        """Wrap ``layers`` for data-parallel training.

        When the strategy reports more than one rank, this checks that a
        parallel context exists and resolves the communication group (eager
        mode). It then broadcasts parameters and buffers (except on XPU
        builds) and builds the gradient reducer. Otherwise it warns that
        training falls back to single-card mode.

        See the class docstring for the meaning of each argument.
        """
        super().__init__(layers.full_name() + "_data_parallel")

        assert (
            _non_static_mode()
        ), "It's not supported to construct DataParallel in static mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        self.grad_need_sync = True
        self.group = group
        self.var_dtype = (
            core.eager.Tensor if in_dygraph_mode() else core.VarBase
        )

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            assert parallel_helper.__parallel_ctx__clz__ is not None, (
                "ParallelContext must be initialized before. You should use init_parallel_env() before "
                "constructing the DataParallel."
            )

            if in_dygraph_mode():
                self.group = (
                    paddle.distributed.collective._get_default_group()
                    if self.group is None
                    else self.group
                )

                assert isinstance(
                    self.group, paddle.distributed.collective.Group
                ), "ProcessGroup must be an instance of Group in DataParallel."

            # sync buffer and params
            # TODO(liuyuhui) Currently not support xpu. xpu is
            # still broadcasting parameters when calling layer
            if not paddle.is_compiled_with_xpu():
                sync_params_buffers(self._layers)

            # Convert buffer sizes from MB to bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group, Default: 1MB. The role of this small group is:
            # when the last group allreduce, the overlap cannot work. Making
            # the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(
                last_comm_buffer_size * 1024 * 1024
            )
            self.init_reducer()
        else:
            warnings.warn(
                "The program will return to single-card operation. "
                "Please check 1, whether you use spawn or fleetrun "
                "to start the program. 2, Whether it is a multi-card "
                "program. 3, Is the current environment multi-card."
            )
690 691 692 693 694 695 696 697 698

    def init_reducer(self):
        """Collect trainable parameters and build the gradient reducer.

        Parameters are gathered from every sublayer (deduplicated). Only
        trainable ones without a truthy ``no_sync`` attribute are kept. They
        are then assigned to communication groups bounded by
        ``[last_comm_buffer_size, comm_buffer_size]``, and an ``EagerReducer``
        (eager mode) or ``Reducer`` (legacy dygraph) is created over them.

        Raises:
            TypeError: if a parameter is not an instance of ``self.var_dtype``.
            AssertionError: if no parameter remains to be synchronized.
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError(
                        "The data type of '%s' must be '%s'"
                        % (param.name, self.var_dtype)
                    )
                if param.trainable:
                    layers_param.append((sublayer, param))

        # Drop ``no_sync`` parameters while keeping each parameter paired with
        # its owning sublayer. Both ``trainable_parameters`` and
        # ``is_sparse_gradient`` are then derived from the same filtered
        # pairs, so they stay index-aligned. Previously the sparse flags were
        # computed over all trainable params, so lengths and order diverged
        # as soon as any parameter was ``no_sync``.
        synced_layers_param = [
            (sublayer, param)
            for sublayer, param in layers_param
            if not getattr(param, "no_sync", False)
        ]
        trainable_parameters = [param for _, param in synced_layers_param]

        assert len(trainable_parameters) > 0, (
            "This model does not have any parameters to train, and "
            "does not need to use DataParallel"
        )

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check will also be removed here.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in synced_layers_param
        ]

        if in_dygraph_mode():
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters,
                is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size],
            )

            self._reducer = core.EagerReducer(
                trainable_parameters,
                list(reversed(self.group_indices)),
                is_sparse_gradient,
                self.group.process_group,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters,
            )
        elif _in_legacy_dygraph():
            self.group_indices = core.assign_group_by_size(
                trainable_parameters,
                is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size],
            )

            self._reducer = core.Reducer(
                trainable_parameters,
                list(reversed(self.group_indices)),
                is_sparse_gradient,
                parallel_helper.__parallel_ctx__clz__,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters,
            )
766 767

    def _find_varbase(self, obj):
        """Collect every tensor nested inside ``obj``.

        Lists, tuples and dict values are searched recursively; any other
        non-tensor object contributes nothing. The tensor type matched is
        ``core.eager.Tensor`` in eager mode and ``core.VarBase`` otherwise.

        Returns:
            ``[obj]`` for a tensor, ``[]`` for an unrelated object, or an
            ``itertools.chain`` over the tensors found in a container.
        """
        tensor_cls = core.eager.Tensor if in_dygraph_mode() else core.VarBase
        if isinstance(obj, tensor_cls):
            return [obj]
        if isinstance(obj, (list, tuple)):
            children = obj
        elif isinstance(obj, dict):
            children = obj.values()
        else:
            return []
        return itertools.chain(*map(self._find_varbase, children))
776

777 778 779
    @contextmanager
    def no_sync(self):
        """
        A context manager to stop gradient synchronization. Within no_sync(),
        gradients of parameters will only be accumulated on model and not
        synchronized until the first forward-backward out of this context.

        Inside the context ``self.grad_need_sync`` is set to ``False``; its
        previous value is restored on exit, even if the body raises.

        Examples:
            .. code-block:: python

                # required: distributed
                import paddle
                import paddle.nn as nn
                import paddle.distributed as dist

                class SimpleNet(nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear(x)

                dist.init_parallel_env()
                model = SimpleNet()
                dp_model = paddle.DataParallel(model)

                inputs_1 = paddle.randn([10, 10], 'float32')
                inputs_2 = paddle.ones([10, 10], 'float32')

                with dp_model.no_sync():
                    # gradients will not be synchronized
                    dp_model(inputs_1).backward()

                # synchronization happens here
                dp_model(inputs_2).backward()

        """
        tmp_grad_need_sync = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            # Restore the caller's setting (not a hard-coded True) so nested
            # no_sync() blocks and exceptions leave the flag consistent.
            self.grad_need_sync = tmp_grad_need_sync

822
    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped layers and, if required, prepare the reducer with
        every tensor found in their outputs.

        The reducer is prepared only when more than one rank participates,
        the current dygraph tracer records gradients, and
        ``self.grad_need_sync`` is true (``no_sync`` clears it).

        Returns:
            The outputs of ``self._layers``, unchanged.
        """
        outputs = self._layers(*inputs, **kwargs)
        must_sync = (
            self._strategy.nranks > 1
            and framework._dygraph_tracer()._has_grad
            and self.grad_need_sync
        )
        if not must_sync:
            return outputs

        output_tensors = list(self._find_varbase(outputs))
        self._reducer.prepare_for_backward(output_tensors)
        return outputs
Y
Yan Xu 已提交
833

834 835 836
    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore."
    )
    def scale_loss(self, loss):
        """
        Kept only for backward compatibility; ``loss`` is handed back as-is.

        Args:
            loss: The loss value.

        Returns:
            ``loss`` itself, unmodified.
        """
        return loss

844 845 846
    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore."
    )
    def apply_collective_grads(self):
        """
        Kept only for backward compatibility. This method does nothing and
        returns ``None``.
        """
853

854 855 856 857 858 859
    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
    ):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers, and set them into a dict.

        This delegates to the wrapped layer (``self._layers``), so the returned
        keys are those produced by the wrapped layer's own ``state_dict``.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Forwarded as-is to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
        )
891

892
    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict.

        This delegates to the wrapped layer's (``self._layers``) ``set_state_dict``.

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict, use_structured_name=use_structured_name
        )
926 927 928 929

    # [aliases] Older method names kept for compatibility; both are bound to
    # the very same function object as set_state_dict.
    set_dict = load_dict = set_state_dict