parallel.py 32.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16
import six
Y
Yan Xu 已提交
17
import numpy as np
18
import warnings
19
from collections import OrderedDict
S
ShenLiang 已提交
20 21
import itertools
import warnings
22
from contextlib import contextmanager
23

S
ShenLiang 已提交
24
import paddle
25
from paddle import _C_ops
26 27 28 29 30 31
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
32
from ..layers import collective
33
from paddle.fluid.dygraph import base as imperative_base
34
from paddle.fluid.framework import ParamBase, _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
35

__all__ = [
    "prepare_context",
    "ParallelEnv",
    "DataParallel",
]

# Module-level alias for ``core.ParallelStrategy`` so the helpers below can
# build strategy objects without spelling out the ``core`` prefix.
ParallelStrategy = core.ParallelStrategy


@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the dygraph parallel context (deprecated, use
    ``paddle.distributed.init_parallel_env`` instead).

    Args:
        strategy (ParallelStrategy, optional): parallel settings. If None, a
            strategy is built from the current ``Env`` (environment
            variables). Default: None.

    Returns:
        ParallelStrategy|None: ``None`` when ``strategy.nranks < 2`` (nothing
        to prepare), otherwise the strategy that was used.

    Raises:
        NotImplementedError: if no parallel context has been initialized yet
            and the current place is not a CUDAPlace, XPUPlace or NPUPlace.
    '''
    if strategy is None:
        # Read the environment once instead of constructing Env() per field.
        env = Env()
        strategy = ParallelStrategy()
        strategy.nranks = env.nranks
        strategy.local_rank = env.local_rank
        strategy.trainer_endpoints = env.trainer_endpoints
        strategy.current_endpoint = env.current_endpoint
    if strategy.nranks < 2:
        return None

    assert framework._non_static_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."

    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_ctx = core.NCCLParallelContext(strategy, place)
        elif isinstance(place, core.XPUPlace):
            parallel_ctx = core.BKCLParallelContext(strategy, place)
        elif isinstance(place, core.NPUPlace):
            parallel_ctx = core.HCCLParallelContext(strategy, place)
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # Raise explicitly: an ``assert`` on a non-empty string literal is
            # always true and would let unsupported places fall through.
            raise NotImplementedError(
                "Only support CUDAPlace or XPUPlace or NPUPlace for now.")
        parallel_helper._set_parallel_ctx(parallel_ctx)
        parallel_helper._init_parallel_ctx()
    return strategy


class ParallelEnv(object):
    """
    .. note::
        This API is not recommended. If you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # imperative only supports one device per process, so only the first
        # id of the FLAGS_selected_* list is used.
        if core.is_compiled_with_cuda():
            selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
            self._device_id = int(selected_gpus[0])
        elif core.is_compiled_with_xpu():
            selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
            self._device_id = int(selected_xpus[0])
        elif core.is_compiled_with_npu():
            selected_npus = os.getenv("FLAGS_selected_npus", "0").split(",")
            self._device_id = int(selected_npus[0])
        elif core.is_compiled_with_mlu():
            selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",")
            self._device_id = int(selected_mlus[0])
        else:
            # No GPU/XPU/NPU/MLU support compiled in: fall back to the
            # documented default so ``device_id`` does not raise
            # AttributeError.
            self._device_id = 0

        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of selected GPU card for parallel training.

        Its value is the first entry of the environment variable ``FLAGS_selected_gpus``
        (or ``FLAGS_selected_xpus`` / ``FLAGS_selected_npus`` / ``FLAGS_selected_mlus``
        on builds for those devices). The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS`` split on commas.
        If the variable is unset, the value is ``['']``.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        The number of NCCL rings of the current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.
        It must be greater than 0 and less than 9.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

# NOTE: [ Compatible ] This class was originally named ``Env``. That name is
# imprecise and may confuse users, so the class was renamed ``ParallelEnv``;
# ``Env`` is kept as an alias so that older examples keep working.
Env = ParallelEnv


def _build_default_parallel_strategy():
    """
    Build a ``ParallelStrategy`` populated from the current ``ParallelEnv``.

    Returns:
        ParallelStrategy: with ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` copied from the
        environment.
    """
    # Snapshot the environment once; every ParallelEnv() re-reads and
    # re-parses the environment variables.
    env = ParallelEnv()
    strategy = ParallelStrategy()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """
    Flatten and concatenate the tensors of each group into one 1-D tensor.

    Args:
        var_groups (dict): maps a group id to the list of tensors in that
            group (``build_groups`` builds it as an ``OrderedDict``).

    Returns:
        list: one ``[coalesced_grad, grad_vars, g_var_shapes]`` entry per
        group, in the dict's iteration order. ``coalesced_grad`` is the
        concatenation of the group's flattened tensors and ``g_var_shapes``
        records each tensor's original shape so it can be restored after a
        split.
    """
    from ..layers import nn
    coalesced_grads_and_grad_vars = []
    # Only the ordered groups are needed; the group ids are not used.
    for grad_vars in var_groups.values():
        g_var_shapes = [g_var.shape for g_var in grad_vars]
        # np.prod returns a float (1.0) for an empty shape, so cast to int to
        # keep the reshape target integral.
        flattened_vars = [
            nn.reshape(x=g_var, shape=[int(np.prod(g_shape))])
            for g_var, g_shape in zip(grad_vars, g_var_shapes)
        ]
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes])
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """
    Reshape ``x`` in place by tracing a ``reshape2`` op whose ``Out`` is ``x``.

    Args:
        x: dygraph tensor to reshape; it is both the input and the output.
        shape (list[int]): the target shape.
    """
    # reshape2 also produces an ``XShape`` output; it is not used afterwards.
    xshape_holder = framework._varbase_creator(dtype=x.dtype)
    tracer = framework._dygraph_tracer()
    tracer.trace_op(
        type="reshape2",
        inputs={'X': x},
        outputs={'Out': x, 'XShape': xshape_holder},
        attrs={'shape': shape},
    )


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """
    Scatter each coalesced 1-D tensor back into its original tensors.

    Each entry is ``[coalesced_grad, origin_grad_vars, grad_shapes]`` (the
    layout produced by ``_coalesce_tensors``). ``coalesced_grad`` is split
    along axis 0 into ``origin_grad_vars``; each output is then reshaped in
    place back to its recorded shape. Legacy dygraph traces a ``split`` op,
    eager mode calls ``_C_ops.split``; in any other mode nothing is done.
    """
    use_legacy = _in_legacy_dygraph()
    if not use_legacy and not in_dygraph_mode():
        return

    for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
        sections = [np.prod(g_shape) for g_shape in grad_shapes]
        if use_legacy:
            framework._dygraph_tracer().trace_op(
                type='split',
                inputs={'X': coalesced_grad},
                outputs={'Out': origin_grad_vars},
                attrs={
                    'sections': sections,
                    'axis': 0
                })
        else:
            _C_ops.split(coalesced_grad, origin_grad_vars, 'sections',
                         sections, 'axis', 0)

        for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
            if use_legacy:
                _reshape_inplace(x=g_var, shape=g_shape)
            else:
                g_var.reshape_(shape=g_shape)
            assert g_var.shape == g_shape


def scale_loss(loss):
    """
    Divide ``loss`` by the world size when more than one trainer is running.

    Args:
        loss: the loss tensor.

    Returns:
        ``loss`` unchanged when world size is at most 1, otherwise
        ``loss / world_size`` (the divisor is a float32 constant with
        ``stop_gradient=True``).
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    # Read the environment once; each ParallelEnv() re-parses it.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    loss_scale = to_variable(np.array([world_size]).astype("float32"))
    loss_scale.stop_gradient = True
    return loss / loss_scale


@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """
    Split ``vars`` into consecutive groups and coalesce each group.

    Tensors are visited in order. A tensor joins the current group while the
    group's accumulated byte count is still below ``group_size`` and its
    dtype matches the group's dtype; otherwise it opens a new group. The
    tensor that crosses the limit stays in the current group, so a group may
    exceed ``group_size`` by up to one tensor.

    Args:
        vars (list): tensors to group. The name shadows the builtin ``vars``
            but is kept for keyword-argument compatibility.
        group_size (int): byte threshold used to close a group.

    Returns:
        list: the ``_coalesce_tensors`` result, one
        ``[coalesced, tensors, shapes]`` entry per group; an empty list when
        ``vars`` is empty.
    """
    # Guard: ``vars[0]`` below would raise IndexError on an empty input.
    if len(vars) == 0:
        return []

    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        nbytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += nbytes
        else:
            # Start a new group seeded with this tensor.
            memory_counter = nbytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(model,
                        comm_group=None,
                        src_rank=0,
                        is_model_parallel=False):
    """
    Broadcast ``model``'s parameters and buffers from ``src_rank``.

    Eligible tensors are detached and packed into coalesced groups by
    ``build_groups`` (128MB threshold). Every coalesced tensor is broadcast
    over ``comm_group`` with ``use_calc_stream=True``, and afterwards each one
    is split back into the detached tensors.

    Skipped entries:
      * ``is_distributed`` parameters when ``is_model_parallel`` is True;
      * parameters with a truthy ``no_sync`` attribute;
      * tensors whose type is ``VarType.VOCAB``.

    Raises:
        TypeError: if an entry is neither a ``core.VarBase`` nor a
            ``core.eager.Tensor``.
    """

    def _excluded(param):
        # Model-parallel / no_sync filtering applies only to ParamBase or
        # eager Tensor entries.
        if isinstance(param, (ParamBase, core.eager.Tensor)):
            # is_distributed param not need to sync when in mp mode
            if is_model_parallel and param.is_distributed:
                return True
            # NOTE(shenliang03): Support situations that do not require
            # synchronization parameters, such as moe's expert parameters
            if getattr(param, "no_sync", False):
                return True
        return param.type == core.VarDesc.VarType.VOCAB

    model_vars = []
    for _name, param in model._obtain_parameters_buffers().items():
        if not isinstance(param, (core.VarBase, core.eager.Tensor)):
            raise TypeError(
                "The data type of '%s' must be Varbase or eager.Tensor" %
                param.name)
        if _excluded(param):
            continue
        model_vars.append(param.detach())

    if len(model_vars) == 0:
        return

    # group size is 128M
    coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)

    # Broadcast every coalesced tensor first ...
    for coalesced_var, _, _ in coalesced_vars:
        paddle.distributed.broadcast(coalesced_var,
                                     src=src_rank,
                                     group=comm_group,
                                     use_calc_stream=True)

    # ... then split each one back into its detached source tensors.
    for coalesced_var, origin_vars, var_shapes in coalesced_vars:
        sections = [np.prod(v_shape) for v_shape in var_shapes]
        paddle.fluid.framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_var},
            outputs={'Out': origin_vars},
            attrs={
                'sections': sections,
                'axis': 0
            })


419
class DataParallel(layers.Layer):
C
chengduo 已提交
420
    """
421
    Run the dygraph module with data parallelism.
C
chengduo 已提交
422

423
    Currently, DataParallel class only supports to run the dynamic graph
424 425 426 427 428 429 430 431 432 433
    with multi-process. 
    
    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)
    
    2. start by ``paddle.distributed.launch`` module, for example:
    
434
        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
435 436

    And the content of `demo.py` is the code of examples.
C
chengduo 已提交
437

438 439
    Args:
        layers(Layer): The module that should be executed by data parallel.
440 441
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism, 
            contains environment configuration related to parallel execution. Default: None.
442
        comm_buffer_size(int, optional):  It limits the memory size(MB) of one buffer  
443 444
                                          parameters' gradient which is the input of communication 
                                          calling(e.g NCCLAllReduce). Default: 25.
445 446
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
                                         calling. Making the last communication buffer size small is useful to 
447
                                         improve performance. Default: 1.
448 449 450 451 452 453 454 455 456 457 458
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from
                                                all tensors in the return value of the wrapped model's 
                                                forward function. For parameters not involved in loss 
                                                calculation, their gradients will be marked as ready in 
                                                advance to prepare reduce. Please note that all forward 
                                                outputs derived from the wrapped model parameters must 
                                                participate in the calculation of loss and subsequent 
                                                gradient calculations. If not, serious error will occur.
                                                Note that setting the find_unused_parameters to True 
                                                will affect computing performance. Therefore, if all parameters
                                                are sure to participate in the loss calculation and the 
459
                                                autograd graph construction, please set it False. Default: False.
460
            
461 462 463
    Returns:
        Layer: The data paralleled module.

C
chengduo 已提交
464
    Examples:
465

C
chengduo 已提交
466
        .. code-block:: python
467 468
            :name: dp-example

469
            # required: distributed
470 471 472 473 474 475 476 477 478 479 480 481 482 483 484
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
                    
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
485
                # 1. initialize parallel environment
486 487
                dist.init_parallel_env()

488
                # 2. create data parallel layer & optimizer
489 490 491 492 493 494 495
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

496
                # 3. run layer
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
                
                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577


    .. note::
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind, 
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync', 
        and manually implement 'all_reduce' before model optimization. There is an example 
        showing the specific implementation.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            # required: distributed
            import numpy
            import paddle
            import paddle.distributed as dist
            from paddle.autograd import PyLayer
            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.linear = paddle.nn.Linear(2, 2)

                def forward(self, inputs):
                    inputs = cus_tanh.apply(inputs)
                    return self.linear(inputs)

            if __name__ == '__main__':
                dist.init_parallel_env()

                model = SimpleNet()
                model = paddle.DataParallel(model)
                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

                for step in range(10):
                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
                    x = paddle.to_tensor(x_data)
                    x.stop_gradient = False

                    # step 1 : skip gradient synchronization by 'no_sync'
                    with model.no_sync():
                        y_pred = model(x)
                        loss = y_pred.mean()
                        loss.backward()

                    # step 2 : fuse + allreduce manually before optimization
                    fused_allreduce_gradients(list(model.parameters()), None)

                    opt.step()
                    opt.clear_grad()

C
chengduo 已提交
578 579
    """

580 581 582
    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1,
                 find_unused_parameters=False,
                 group=None):
        # Arguments are described in the class docstring.
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        assert _non_static_mode(), \
            "It's not supported to construct DataParallel in static mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        self.grad_need_sync = True
        self.group = group
        # Expected parameter tensor type: eager Tensor in eager mode,
        # VarBase in legacy dygraph mode.
        self.var_dtype = core.eager.Tensor if in_dygraph_mode(
        ) else core.VarBase

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment (the adjacent string literals are
            # concatenated, so the first one must end with a space)
            assert parallel_helper.__parallel_ctx__clz__ is not None, \
                "ParallelContext must be initialized before. You should use init_parallel_env() before " \
                "constructing the DataParallel."

            if in_dygraph_mode():
                self.group = paddle.distributed.collective._get_default_group(
                ) if self.group is None else self.group

                assert isinstance(self.group, paddle.distributed.collective.Group), \
                    "ProcessGroup must be an instance of Group in DataParallel."

            # sync buffer and params
            # TODO(liuyuhui) Currently not support xpu. xpu is
            # still broadcasting parameters when calling layer
            if not paddle.is_compiled_with_xpu():
                sync_params_buffers(self._layers)

            # Buffer sizes are given in MB; convert them to bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group, Default: 1MB. The role of this small group is:
            # when the last group allreduce, the overlap cannot work. Making the
            # the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")

    def init_reducer(self):
        """
        Build the gradient reducer for the trainable parameters of the wrapped layers.

        Parameters are gathered from ``self.sublayers()``, each one at most
        once. They are assigned to communication groups by size, and then
        handed to ``core.EagerReducer`` (eager mode) or ``core.Reducer``
        (legacy dygraph mode).

        Raises:
            TypeError: if a parameter is not an instance of ``self.var_dtype``.
            AssertionError: if there is no trainable parameter.
        """
        visited = set()
        owner_and_param = []
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                if param is None or param in visited:
                    continue
                visited.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError("The data type of '%s' must be '%s'" %
                                    (param.name, self.var_dtype))
                if param.trainable:
                    owner_and_param.append((sublayer, param))

        trainable_parameters = [param for _, param in owner_and_param]

        assert len(trainable_parameters) > 0, \
            "This model does not have any parameters to train, and " \
            "does not need to use DataParallel"

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check will also be removed here.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(owner) for owner, _ in owner_and_param
        ]

        if in_dygraph_mode():
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters, is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size])
            self._reducer = core.EagerReducer(
                trainable_parameters, list(reversed(self.group_indices)),
                is_sparse_gradient, self.group.process_group,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters)
        elif _in_legacy_dygraph():
            self.group_indices = core.assign_group_by_size(
                trainable_parameters, is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size])
            self._reducer = core.Reducer(
                trainable_parameters, list(reversed(self.group_indices)),
                is_sparse_gradient, parallel_helper.__parallel_ctx__clz__,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters)

    def _find_varbase(self, obj):
        """
        Collect every tensor contained in ``obj``.

        ``obj`` may be a tensor or any nesting of lists, tuples and dicts.
        Dict values are searched; keys are not. Any other object contributes
        nothing.

        Returns:
            list: the tensors found, in depth-first order (always a list,
            whatever the input structure).
        """
        # Decide the tensor type once for the whole traversal instead of at
        # every recursion level.
        var_type = core.eager.Tensor if in_dygraph_mode() else core.VarBase
        found = []

        def _collect(item):
            if isinstance(item, var_type):
                found.append(item)
            elif isinstance(item, (list, tuple)):
                for child in item:
                    _collect(child)
            elif isinstance(item, dict):
                for child in item.values():
                    _collect(child)

        _collect(obj)
        return found

710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754
    @contextmanager
    def no_sync(self):
        """
        A context manager that pauses gradient synchronization.

        Inside ``no_sync()``, parameter gradients are only accumulated locally
        on the model. They are synchronized again by the first
        forward-backward pass run outside this context. On exit the previous
        setting is restored (not forced to True), so nested uses compose.

        Examples:
            .. code-block:: python

                # required: distributed
                import paddle
                import paddle.nn as nn
                import paddle.distributed as dist

                class SimpleNet(nn.Layer):
                    def __init__(self):
                        super(SimpleNet, self).__init__()
                        self._linear = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear(x)

                dist.init_parallel_env()
                model = SimpleNet()
                dp_model = paddle.DataParallel(model)

                inputs_1 = paddle.randn([10, 10], 'float32')
                inputs_2 = paddle.ones([10, 10], 'float32')

                with dp_model.no_sync():
                    # gradients will not be synchronized
                    dp_model(inputs_1).backward()

                # synchronization happens here
                dp_model(inputs_2).backward()

        """
        saved_flag, self.grad_need_sync = self.grad_need_sync, False
        try:
            yield
        finally:
            self.grad_need_sync = saved_flag

755
    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped layer and, when needed, prepare the reducer for backward.

        The reducer is handed every tensor found in the layer's outputs only if
        there is more than one rank, the dygraph tracer currently records
        gradients, and gradient synchronization has not been disabled (e.g. by
        ``no_sync``). The layer outputs are returned unchanged.
        """
        outputs = self._layers(*inputs, **kwargs)

        # Short-circuit order matters: the tracer is only queried when running
        # with multiple ranks.
        must_prepare = (self._strategy.nranks > 1
                        and framework._dygraph_tracer()._has_grad
                        and self.grad_need_sync)
        if must_prepare:
            output_tensors = list(self._find_varbase(outputs))
            self._reducer.prepare_for_backward(output_tensors)

        return outputs
Y
Yan Xu 已提交
762

763 764
    @deprecated(since="2.0.0",
                reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated no-op, retained only for backward compatibility.

        ``loss`` is returned exactly as given; calling this method is no
        longer necessary.
        """
        return loss

772 773
    @deprecated(since="2.0.0",
                reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated no-op, retained only for backward compatibility.

        Does nothing and returns None; calling this method is no longer
        necessary.
        """
        return None
780 781 782 783 784 785

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of the current layer and its sub-layers, and set them into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Forwarded unchanged to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''
        # Delegate to the wrapped layer; no extra prefix is added here, so the
        # result is whatever self._layers.state_dict produces for these args.
        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

816
    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict.

        Parameters:
            state_dict(dict) : Dict containing all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''
        # Loading is delegated entirely to the wrapped layer.
        self._layers.set_state_dict(state_dict,
                                    use_structured_name=use_structured_name)
849 850 851 852

    # [aliases] Older method names, kept so existing callers keep working.
    set_dict = load_dict = set_state_dict