parallel.py 33.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
Y
Yan Xu 已提交
16
import numpy as np
17
import warnings
18
from collections import OrderedDict
S
ShenLiang 已提交
19 20
import itertools
import warnings
21
from contextlib import contextmanager
22

S
ShenLiang 已提交
23
import paddle
24
from paddle import _C_ops, _legacy_C_ops
25 26 27 28 29 30
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
31
from ..layers import collective
32
from paddle.fluid.dygraph import base as imperative_base
33
from paddle.fluid.framework import ParamBase, _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
34

35
# Public API exported by ``from ... import *``.
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]

# Alias of the C++ ParallelStrategy type. In this module it carries
# nranks, local_rank, trainer_endpoints and current_endpoint into the
# parallel contexts.
ParallelStrategy = core.ParallelStrategy


40
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the global parallel context used by dygraph data parallelism.

    Deprecated since 2.0.0; use ``paddle.distributed.init_parallel_env``.

    Args:
        strategy (ParallelStrategy, optional): Parallel configuration. When
            None, it is built from the environment variables read by
            ``ParallelEnv``.

    Returns:
        ParallelStrategy|None: The strategy that was used, or None when fewer
        than two ranks are configured (nothing needs to be initialized).

    Raises:
        AssertionError: If not running in dygraph mode, if no expected place
            is set, or if the place is not a CUDA, XPU or NPU place.
    '''
    if strategy is None:
        strategy = _build_default_parallel_strategy()
    if strategy.nranks < 2:
        return

    assert framework._non_static_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."

    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_ctx = core.NCCLParallelContext(strategy, place)
        elif isinstance(place, core.XPUPlace):
            parallel_ctx = core.BKCLParallelContext(strategy, place)
        elif isinstance(place, core.NPUPlace):
            parallel_ctx = core.HCCLParallelContext(strategy, place)
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # An ``assert "<message>"`` on a non-empty string never fails, so the
            # unsupported-place case must be raised explicitly. AssertionError is
            # kept for consistency with the other checks in this function.
            raise AssertionError(
                "Only support CUDAPlace or XPUPlace or NPUPlace for now.")
        parallel_helper._set_parallel_ctx(parallel_ctx)
        parallel_helper._init_parallel_ctx()
    return strategy
73 74


75 76
class ParallelEnv(object):
    """
    .. note::
        This API is not recommended, if you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    All values are read from environment variables once, when the instance
    is constructed.

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 0:
            # rank: 0
            # world_size: 2
            # print result in process 1:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", ""))

        # Each dygraph process uses a single device: only the first id listed
        # in the relevant FLAGS_selected_* variable is kept.
        # Default to 0 (the same fallback every branch uses when its flag is
        # unset) so that ``device_id`` is always defined, including on builds
        # compiled without CUDA/XPU/NPU/MLU support.
        self._device_id = 0
        if self._device_type != "":
            FLAGS_selected_custom_devices = 'FLAGS_selected_{}s'.format(
                self._device_type)
            selected_custom_devices = os.getenv(FLAGS_selected_custom_devices,
                                                "0").split(",")
            self._device_id = int(selected_custom_devices[0])
        else:
            if core.is_compiled_with_cuda():
                selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
                self._device_id = int(selected_gpus[0])
            elif core.is_compiled_with_xpu():
                selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
                self._device_id = int(selected_xpus[0])
            elif core.is_compiled_with_npu():
                selected_npus = os.getenv("FLAGS_selected_npus", "0").split(",")
                self._device_id = int(selected_npus[0])
            elif core.is_compiled_with_mlu():
                selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",")
                self._device_id = int(selected_mlus[0])

        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the device selected for parallel training.

        When ``PADDLE_XCCL_BACKEND`` names a custom device type ``<type>``, this is
        the first id in ``FLAGS_selected_<type>s``. Otherwise it is the first id in
        ``FLAGS_selected_gpus``, ``FLAGS_selected_xpus``, ``FLAGS_selected_npus`` or
        ``FLAGS_selected_mlus``, depending on which device support Paddle was
        compiled with. The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def device_type(self):
        """
        The type of custom device for parallel training.

        Its value is equal to the value of the environment variable ``PADDLE_XCCL_BACKEND`` .
        The default value is an empty string ``""`` (no custom device).

        """
        return self._device_type

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS`` split on
        commas. When the variable is unset, the value is ``['']``.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.
        It must be an integer in the range [1, 8]; other values fail an assertion at construction.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # the number of ring is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

282

283 284 285 286 287 288
# NOTE: [ Compatible ] This class was originally named `Env`. That name is
# inaccurate and may confuse users, so the class was renamed to `ParallelEnv`.
# The old name is kept as an alias for compatibility with older examples.
Env = ParallelEnv


289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` filled in from the current ``ParallelEnv``.

    Returns:
        ParallelStrategy: a strategy whose ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` mirror the values that
        ``ParallelEnv`` reads from the environment.
    """
    default_strategy = ParallelStrategy()
    env = ParallelEnv()
    default_strategy.nranks = env.nranks
    default_strategy.local_rank = env.local_rank
    default_strategy.trainer_endpoints = env.trainer_endpoints
    default_strategy.current_endpoint = env.current_endpoint
    return default_strategy


def _coalesce_tensors(var_groups):
    """Flatten and concatenate each group of tensors into a single 1-D tensor.

    Args:
        var_groups (dict): maps a group id to a list of tensors. Only the
            values are used, in the dict's iteration order.

    Returns:
        list: one ``[coalesced, tensors, shapes]`` entry per group.
        ``coalesced`` is the concatenation of the flattened tensors, and
        ``shapes`` records each tensor's original shape so the tensors can be
        reshaped back after splitting (see ``_split_tensors``).
    """
    from ..layers import nn

    packed = []
    for members in var_groups.values():
        shapes = [member.shape for member in members]
        flat_members = [
            nn.reshape(x=member, shape=[np.prod(shape)])
            for member, shape in zip(members, shapes)
        ]
        packed.append([nn.concat(flat_members), members, shapes])
    return packed


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape ``x`` to ``shape`` in place by tracing a ``reshape2`` op.

    ``x`` is passed as both the input ``X`` and the output ``Out``. ``XShape``
    receives a scratch variable created with ``x``'s dtype.
    """
    scratch_xshape = framework._varbase_creator(dtype=x.dtype)
    tracer = framework._dygraph_tracer()
    tracer.trace_op(type="reshape2",
                    inputs={'X': x},
                    outputs={'Out': x, 'XShape': scratch_xshape},
                    attrs={'shape': shape})
324 325 326 327


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter coalesced 1-D buffers back into their original tensors.

    Inverse of ``_coalesce_tensors``. Each ``[coalesced, tensors, shapes]``
    entry is split along axis 0 into sections sized by each shape's element
    count. The sections are written into ``tensors``, and each tensor is then
    reshaped in place back to its recorded shape.

    Legacy dygraph traces a ``split`` op and uses ``_reshape_inplace``. Eager
    mode calls ``_legacy_C_ops.split`` and ``Tensor.reshape_``. In any other
    mode nothing is done.
    """
    use_legacy = _in_legacy_dygraph()
    if not use_legacy and not in_dygraph_mode():
        return

    for coalesced, targets, shapes in coalesced_grads_and_grad_vars:
        sections = [np.prod(shape) for shape in shapes]
        if use_legacy:
            framework._dygraph_tracer().trace_op(
                type='split',
                inputs={'X': coalesced},
                outputs={'Out': targets},
                attrs={
                    'sections': sections,
                    'axis': 0
                })
        else:
            _legacy_C_ops.split(coalesced, targets, 'sections', sections,
                                'axis', 0)

        for target, shape in zip(targets, shapes):
            if use_legacy:
                _reshape_inplace(x=target, shape=shape)
            else:
                target.reshape_(shape=shape)
            assert target.shape == shape
352 353 354


def scale_loss(loss):
    """Divide ``loss`` by the world size.

    Returns ``loss`` unchanged when the world size is 1 or less. Otherwise it
    returns ``loss / world_size``, where the divisor is a float32 tensor with
    ``stop_gradient`` set to True.
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    divisor = to_variable(np.array([world_size]).astype("float32"))
    divisor.stop_gradient = True
    return loss / divisor


366 367
@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Bucket tensors into size-limited, dtype-homogeneous groups and coalesce them.

    Tensors are visited in order. A tensor joins the current group while the
    group's accumulated byte count is still below ``group_size`` and its dtype
    matches the group's dtype. Otherwise it starts a new group. Because the
    size check happens before the tensor is added, a group may exceed
    ``group_size`` by up to one tensor.

    Args:
        vars (list): tensors to group. The parameter name shadows the builtin
            but is kept for backward compatibility with keyword callers.
        group_size (int): soft limit, in bytes, for each group.

    Returns:
        list: the result of ``_coalesce_tensors`` for the groups, i.e. one
        ``[coalesced, tensors, shapes]`` entry per group. Returns an empty list
        when ``vars`` is empty, instead of failing on ``vars[0]``.
    """
    if len(vars) == 0:
        return []

    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        nbytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += nbytes
        else:
            # Start a new group: either the current one is full or the dtype changed.
            memory_counter = nbytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(model,
                        comm_group=None,
                        src_rank=0,
                        is_model_parallel=False):
    """Broadcast a model's parameters and buffers from ``src_rank``.

    An entry is skipped when:
      * it is distributed (only when ``is_model_parallel`` is True),
      * it is marked ``no_sync``, or
      * its type is ``VOCAB``.

    The remaining entries are detached and packed with ``build_groups`` using a
    128MB group-size limit. The packed buffers are broadcast over
    ``comm_group`` and then split back into the detached tensors.

    Raises:
        TypeError: if an entry is neither a ``core.VarBase`` nor a
            ``core.eager.Tensor``.
    """

    def _needs_sync(param):
        # Type validation comes first, so an unsupported entry always raises.
        if not isinstance(param, (core.VarBase, core.eager.Tensor)):
            raise TypeError(
                "The data type of '%s' must be Varbase or eager.Tensor" %
                param.name)
        if isinstance(param, (ParamBase, core.eager.Tensor)):
            # is_distributed param not need to sync when in mp mode
            if is_model_parallel and param.is_distributed:
                return False
            # NOTE(shenliang03): Support situations that do not require synchronization parameters,
            # such as moe's expert parameters
            if getattr(param, "no_sync", False):
                return False
        if param.type == core.VarDesc.VarType.VOCAB:
            return False
        return True

    model_vars = [
        param.detach()
        for _, param in model._obtain_parameters_buffers().items()
        if _needs_sync(param)
    ]
    if not model_vars:
        return

    # group size is 128M
    packed_groups = build_groups(model_vars, 128 * 1024 * 1024)

    for packed, _, _ in packed_groups:
        paddle.distributed.broadcast(packed,
                                     src=src_rank,
                                     group=comm_group,
                                     sync_op=True)

    tracer = paddle.fluid.framework._dygraph_tracer
    for packed, targets, shapes in packed_groups:
        sections = [np.prod(shape) for shape in shapes]
        tracer().trace_op(type='split',
                          inputs={'X': packed},
                          outputs={'Out': targets},
                          attrs={
                              'sections': sections,
                              'axis': 0
                          })
434 435


436
class DataParallel(layers.Layer):
C
chengduo 已提交
437
    """
438
    Run the dygraph module with data parallelism.
C
chengduo 已提交
439

440
    Currently, DataParallel class only supports to run the dynamic graph
441 442
    with multi-process.

443 444 445 446 447
    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)
448

449
    2. start by ``paddle.distributed.launch`` module, for example:
450

451
        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
452 453

    And the content of `demo.py` is the code of examples.
C
chengduo 已提交
454

455 456
    Args:
        layers(Layer): The module that should be executed by data parallel.
457
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
458
            contains environment configuration related to parallel execution. Default: None.
459 460
        comm_buffer_size(int, optional):  It limits the memory size(MB) of one buffer
                                          parameters' gradient which is the input of communication
461
                                          calling(e.g NCCLAllReduce). Default: 25.
462
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
463
                                         calling. Making the last communication buffer size small is useful to
464
                                         improve performance. Default: 1.
465
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the
466 467 468 469 470 471
                                                all tensors in the return value of the wrapped model's
                                                forward function. For parameters not involved in loss
                                                calculation, their gradients will be marked as ready in
                                                advance to prepare reduce. Please note that all forward
                                                outputs derived from the wrapped model parameters must
                                                participate in the calculation of loss and subsequent
472
                                                gradient calculations. If not, serious error will occur.
473
                                                Note that setting the find_unused_parameters to True
474
                                                will affect computing performance. Therefore, if all parameters
475
                                                are sure to participate in the loss calculation and the
476
                                                autograd graph construction, please set it False. Default: False.
477

478 479 480
    Returns:
        Layer: The data paralleled module.

C
chengduo 已提交
481
    Examples:
482

C
chengduo 已提交
483
        .. code-block:: python
484 485
            :name: dp-example

486
            # required: distributed
487 488 489 490 491 492 493 494 495 496
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
497

498 499 500 501
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
502
                # 1. initialize parallel environment
503 504
                dist.init_parallel_env()

505
                # 2. create data parallel layer & optimizer
506 507 508 509 510 511 512
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

513
                # 3. run layer
514 515 516 517
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
518

519 520 521 522 523 524 525 526 527 528
                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
529 530 531


    .. note::
532 533 534
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync',
        and manually implement 'all_reduce' before model optimization. There is an example
535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594
        showing the specific implementation.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            # required: distributed
            import numpy
            import paddle
            import paddle.distributed as dist
            from paddle.autograd import PyLayer
            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.linear = paddle.nn.Linear(2, 2)

                def forward(self, inputs):
                    inputs = cus_tanh.apply(inputs)
                    return self.linear(inputs)

            if __name__ == '__main__':
                dist.init_parallel_env()

                model = SimpleNet()
                model = paddle.DataParallel(model)
                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

                for step in range(10):
                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
                    x = paddle.to_tensor(x_data)
                    x.stop_gradient = False

                    # step 1 : skip gradient synchronization by 'no_sync'
                    with model.no_sync():
                        y_pred = model(x)
                        loss = y_pred.mean()
                        loss.backward()

                    # step 2 : fuse + allreduce manually before optimization
                    fused_allreduce_gradients(list(model.parameters()), None)

                    opt.step()
                    opt.clear_grad()

C
chengduo 已提交
595 596
    """

597 598 599
    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1,
                 find_unused_parameters=False,
                 group=None):
        """Wrap ``layers`` for data-parallel training.

        See the class docstring for ``layers``, ``strategy``,
        ``comm_buffer_size``, ``last_comm_buffer_size`` and
        ``find_unused_parameters``. ``group`` is the communication group. In
        eager (dygraph) mode a None ``group`` is replaced by the default
        group, and the result must be a
        ``paddle.distributed.collective.Group``.

        When there is more than one rank, parameters and buffers are
        synchronized (except on XPU builds) and the gradient reducer is
        created. Otherwise a warning is emitted and the wrapper runs the
        module as a single card.
        """
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        assert _non_static_mode(), \
            "It's not supported to construct DataParallel in static mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        self.grad_need_sync = True
        self.group = group
        self.var_dtype = core.eager.Tensor if in_dygraph_mode(
        ) else core.VarBase

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            assert parallel_helper.__parallel_ctx__clz__ is not None, \
                "ParallelContext must be initialized before. You should use " \
                "init_parallel_env() before constructing the DataParallel."

            if in_dygraph_mode():
                self.group = paddle.distributed.collective._get_default_group(
                ) if self.group is None else self.group

                assert isinstance(self.group, paddle.distributed.collective.Group), \
                    "ProcessGroup must be an instance of Group in DataParallel."

            # sync buffer and params
            # TODO(liuyuhui) Currently not support xpu. xpu is
            # still broadcasting parameters when calling layer
            if not paddle.is_compiled_with_xpu():
                sync_params_buffers(self._layers)

            # Buffer sizes are given in MB; convert them to bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): The ``last_comm_buffer_size`` argument (default:
            # 1MB) sets the size of the last group. The role of this small group
            # is that overlap cannot work for the last group's allreduce, so
            # making the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")
658 659 660 661 662 663 664 665 666

    def init_reducer(self):
        """Create the gradient reducer for the wrapped module.

        Steps:
          1. Collect each distinct, non-None parameter from every sublayer.
          2. Keep the trainable ones and record which ones belong to sparse
             embeddings.
          3. Group the parameters by size using ``last_comm_buffer_size`` and
             ``comm_buffer_size``.
          4. Build a ``core.EagerReducer`` (eager mode) or a ``core.Reducer``
             (legacy dygraph).

        Raises:
            TypeError: if a parameter is not an instance of ``self.var_dtype``.
            AssertionError: if the module has no trainable parameter.
        """
        visited = set()
        trainable_pairs = []  # (owning sublayer, parameter) in discovery order
        for owner in self.sublayers():
            for _, param in owner.named_parameters(include_sublayers=False):
                if param is None or param in visited:
                    continue
                visited.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError("The data type of '%s' must be '%s'" %
                                    (param.name, self.var_dtype))
                if param.trainable:
                    trainable_pairs.append((owner, param))

        trainable_parameters = [param for _, param in trainable_pairs]

        assert len(trainable_parameters) > 0, \
            "This model does not have any parameters to train, and " \
            "does not need to use DataParallel"

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check will also be removed here.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(owner) for owner, _ in trainable_pairs
        ]

        if in_dygraph_mode():
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters, is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size])
            reversed_indices = list(reversed(self.group_indices))
            self._reducer = core.EagerReducer(
                trainable_parameters, reversed_indices, is_sparse_gradient,
                self.group.process_group,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters)
        elif _in_legacy_dygraph():
            self.group_indices = core.assign_group_by_size(
                trainable_parameters, is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size])
            reversed_indices = list(reversed(self.group_indices))
            self._reducer = core.Reducer(
                trainable_parameters, reversed_indices, is_sparse_gradient,
                parallel_helper.__parallel_ctx__clz__,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters)
716 717

    def _find_varbase(self, obj):
        """Collect every tensor nested inside ``obj``.

        The tensor type is ``core.eager.Tensor`` in eager mode and
        ``core.VarBase`` otherwise.

        Returns:
            ``[obj]`` when ``obj`` is a tensor of that type. For lists and
            tuples, or for dict values, an ``itertools.chain`` over the
            recursive results. An empty list for anything else.
        """
        tensor_type = core.eager.Tensor if in_dygraph_mode() else core.VarBase
        if isinstance(obj, tensor_type):
            return [obj]

        if isinstance(obj, (list, tuple)):
            children = obj
        elif isinstance(obj, dict):
            children = obj.values()
        else:
            return []
        return itertools.chain(
            *[self._find_varbase(child) for child in children])
726

727 728 729
    @contextmanager
    def no_sync(self):
        """
        A context manager to stop gradient synchronization.

        While the context is active, ``forward`` does not prepare the reducer
        for backward. Gradients of parameters are therefore only accumulated
        locally and are not synchronized until the first forward-backward run
        outside this context. The previous setting is restored on exit, even
        if an exception is raised.

        Examples:
            .. code-block:: python

                # required: distributed
                import paddle
                import paddle.nn as nn
                import paddle.distributed as dist

                class SimpleNet(nn.Layer):
                    def __init__(self):
                        super(SimpleNet, self).__init__()
                        self._linear = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear(x)

                dist.init_parallel_env()
                model = SimpleNet()
                dp_model = paddle.DataParallel(model)

                inputs_1 = paddle.randn([10, 10], 'float32')
                inputs_2 = paddle.ones([10, 10], 'float32')

                with dp_model.no_sync():
                    # gradients will not be synchronized
                    dp_model(inputs_1).backward()

                # synchronization happens here
                dp_model(inputs_2).backward()

        """
        previous_flag = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            # Restore whatever was set before entering, so nesting works.
            self.grad_need_sync = previous_flag

772
    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped layer on the given inputs.

        The reducer is prepared only when all three of these hold:

        - more than one rank is in use;
        - the current dygraph tracer reports ``_has_grad``;
        - gradient synchronization is enabled (``grad_need_sync``, which
          ``no_sync()`` turns off).

        In that case every tensor found in the outputs is passed to
        ``self._reducer.prepare_for_backward`` before returning.

        Returns:
            Whatever the wrapped layer returns, unchanged.
        """
        outputs = self._layers(*inputs, **kwargs)

        # Evaluated in the original order, so the tracer is only queried
        # when running with more than one rank.
        sync_required = (self._strategy.nranks > 1
                         and framework._dygraph_tracer()._has_grad
                         and self.grad_need_sync)
        if not sync_required:
            return outputs

        output_tensors = list(self._find_varbase(outputs))
        self._reducer.prepare_for_backward(output_tensors)
        return outputs

    @deprecated(since="2.0.0",
                reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated no-op, kept only so existing callers keep working.

        Returns ``loss`` unchanged.
        """
        unchanged_loss = loss
        return unchanged_loss

    @deprecated(since="2.0.0",
                reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated no-op, kept only so existing callers keep working.

        Does nothing and returns ``None``.
        """
        return None

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of the current layer and its sub-layers, and set them into a dict.

        The call is delegated to the wrapped layer, and ``structured_name_prefix``
        is passed through unchanged. The ``DataParallel`` wrapper therefore adds
        nothing to the keys.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix forwarded to the wrapped layer's ``state_dict`` for the structured-name keys. Default: ""

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensors in the state_dict.

        The call is delegated to the wrapped layer's ``set_state_dict``.

        Parameters:
            state_dict(dict) : Dict containing all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(state_dict,
                                    use_structured_name=use_structured_name)

    # [aliases] Old method names kept as aliases of ``set_state_dict`` for
    # backward compatibility.
    set_dict = load_dict = set_state_dict