parallel.py 32.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16
import six
Y
Yan Xu 已提交
17
import numpy as np
18
import warnings
19
from collections import OrderedDict
S
ShenLiang 已提交
20 21
import itertools
import warnings
22
from contextlib import contextmanager
23

S
ShenLiang 已提交
24
import paddle
25 26 27 28 29 30
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
31
from ..layers import collective
32
from paddle.fluid.dygraph import base as imperative_base
33
from paddle.fluid.framework import ParamBase, _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
34

35
__all__ = [
    "prepare_context",
    "ParallelEnv",
    "DataParallel",
]

# Module-level alias for ``core.ParallelStrategy``; the strategy objects
# built in this module are instances of this type.
ParallelStrategy = core.ParallelStrategy


@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the dygraph parallel context (deprecated, use
    ``paddle.distributed.init_parallel_env`` instead).

    Args:
        strategy (ParallelStrategy, optional): The strategy to use. When None,
            one is built from the environment via ``Env()``.

    Returns:
        ParallelStrategy|None: The strategy that was used, or None when fewer
        than two ranks are configured (no context is created in that case).

    Raises:
        NotImplementedError: If no parallel context exists yet and the current
            place is not a CUDAPlace, XPUPlace or NPUPlace.
    '''
    if strategy is None:
        # Read the environment once instead of constructing Env() per field.
        env = Env()
        strategy = ParallelStrategy()
        strategy.nranks = env.nranks
        strategy.local_rank = env.local_rank
        strategy.trainer_endpoints = env.trainer_endpoints
        strategy.current_endpoint = env.current_endpoint
    if strategy.nranks < 2:
        return
    assert framework._non_static_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_helper._set_parallel_ctx(
                core.NCCLParallelContext(strategy, place))
        elif isinstance(place, core.XPUPlace):
            parallel_helper._set_parallel_ctx(
                core.BKCLParallelContext(strategy, place))
        elif isinstance(place, core.NPUPlace):
            parallel_helper._set_parallel_ctx(
                core.HCCLParallelContext(strategy, place))
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # Raise explicitly: asserting a bare (non-empty) string always
            # passes, which would let execution continue to
            # _init_parallel_ctx() without any context having been set.
            raise NotImplementedError(
                "Only support CUDAPlace or XPUPlace or NPUPlace for now.")
        parallel_helper._init_parallel_ctx()
    return strategy


class ParallelEnv(object):
    """
    .. note::
        This API is not recommended. If you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # Imperative mode only uses one device card per process, so only the
        # first id listed in the device-selection flag is kept. The checks
        # stay in an if/elif chain so later is_compiled_with_* functions are
        # only called when the earlier ones return False.
        device_flag = None
        if core.is_compiled_with_cuda():
            device_flag = "FLAGS_selected_gpus"
        elif core.is_compiled_with_xpu():
            device_flag = "FLAGS_selected_xpus"
        elif core.is_compiled_with_npu():
            device_flag = "FLAGS_selected_npus"
        elif core.is_compiled_with_mlu():
            device_flag = "FLAGS_selected_mlus"
        # Fall back to the documented default of 0 so that ``device_id`` is
        # defined even on builds without any of the devices above.
        self._device_id = 0
        if device_flag is not None:
            self._device_id = int(os.getenv(device_flag, "0").split(",")[0])

        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the device card selected for parallel training.

        Its value is the first id listed in the environment variable
        ``FLAGS_selected_gpus`` (or ``FLAGS_selected_xpus`` /
        ``FLAGS_selected_npus`` / ``FLAGS_selected_mlus``, depending on which
        device Paddle was compiled with). The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the comma-separated list from the environment variable
        ``PADDLE_TRAINER_ENDPOINTS``. If the variable is unset, the value is
        ``['']`` (the result of splitting an empty string).

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Number of NCCL rings of the current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.
        It must be greater than 0 and less than 9.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

# NOTE: [ Compatible ] This class was originally named `Env`. That name does
# not describe its purpose well and may confuse users, so it was renamed to
# `ParallelEnv`. The old name is kept as an alias so existing examples work.
Env = ParallelEnv


def _build_default_parallel_strategy():
    """Return a ParallelStrategy filled in from the current ParallelEnv.

    Copies ``nranks``, ``local_rank``, ``trainer_endpoints`` and
    ``current_endpoint`` from the environment into a fresh strategy.
    """
    env = ParallelEnv()
    default_strategy = ParallelStrategy()
    default_strategy.nranks = env.nranks
    default_strategy.local_rank = env.local_rank
    default_strategy.trainer_endpoints = env.trainer_endpoints
    default_strategy.current_endpoint = env.current_endpoint
    return default_strategy


def _coalesce_tensors(var_groups):
    """Flatten and concatenate each group of variables into one buffer.

    Args:
        var_groups (dict): Maps a group id to the list of variables in that
            group. Only the values are used; the ids themselves are ignored.

    Returns:
        list: One ``[coalesced_grad, grad_vars, g_var_shapes]`` entry per
        group, in the dict's iteration order. ``coalesced_grad`` is the
        concatenation of the group's variables, each reshaped to 1-D, and
        ``g_var_shapes`` records each variable's original shape.
    """
    from ..layers import nn

    packed_groups = []
    for grad_vars in var_groups.values():
        shapes = [g_var.shape for g_var in grad_vars]
        flattened = [
            nn.reshape(
                x=g_var, shape=[np.prod(shape)])
            for g_var, shape in zip(grad_vars, shapes)
        ]
        packed_groups.append([nn.concat(flattened), grad_vars, shapes])
    return packed_groups


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape the dygraph variable ``x`` to ``shape`` in place.

    Traces a ``reshape2`` op whose ``Out`` is ``x`` itself. The op's
    ``XShape`` output is written to a local variable of the same dtype that
    is discarded afterwards.
    """
    xshape_out = framework._varbase_creator(dtype=x.dtype)
    tracer = framework._dygraph_tracer()
    op_outputs = {'Out': x, 'XShape': xshape_out}
    tracer.trace_op(
        type="reshape2",
        inputs={'X': x},
        outputs=op_outputs,
        attrs={'shape': shape})


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter coalesced buffers back into their original variables.

    Each entry is ``(coalesced_grad, origin_grad_vars, grad_shapes)``, the
    layout produced by ``_coalesce_tensors``. The buffer is split along
    axis 0 into sections sized by each shape's element count and written
    into ``origin_grad_vars``. Each variable is then reshaped in place back
    to its recorded shape, and the result is checked with an assert.
    """
    for entry in coalesced_grads_and_grad_vars:
        packed, targets, shapes = entry
        section_sizes = [np.prod(shape) for shape in shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': packed},
            outputs={'Out': targets},
            attrs={'sections': section_sizes,
                   'axis': 0})
        for target, shape in zip(targets, shapes):
            _reshape_inplace(x=target, shape=shape)
            assert target.shape == shape


def scale_loss(loss):
    """Divide ``loss`` by the world size when more than one trainer runs.

    Returns ``loss`` unchanged when ``ParallelEnv().world_size`` is 1 or
    less. Otherwise the divisor is a float32 variable with
    ``stop_gradient=True``.
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    divisor = to_variable(np.array([world_size]).astype("float32"))
    divisor.stop_gradient = True
    return loss / divisor


@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Pack ``vars`` into coalesced groups of roughly ``group_size`` bytes.

    Variables are visited in order. A variable joins the current group while
    the group's running byte count is still below ``group_size`` and its
    dtype matches the group's dtype; otherwise it starts a new group. The
    size check happens before the variable's bytes are added, so a group may
    end up larger than ``group_size``.

    Args:
        vars (list): Dygraph variables to group.
        group_size (int): Soft per-group limit in bytes.

    Returns:
        list: The result of ``_coalesce_tensors`` for the groups, i.e. one
        ``[coalesced_var, origin_vars, var_shapes]`` entry per group. An
        empty ``vars`` yields an empty list.
    """
    if not vars:
        # Nothing to group; also avoids an IndexError on vars[0] below.
        return []

    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        # `nbytes` rather than `bytes` to avoid shadowing the builtin.
        nbytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += nbytes
        else:
            # Size limit reached or dtype changed: open a new group.
            memory_counter = nbytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(model,
                        comm_group=None,
                        src_rank=0,
                        is_model_parallel=False):
    """Broadcast the parameters and buffers of ``model`` from ``src_rank``.

    Every entry of ``model._obtain_parameters_buffers()`` must be a
    ``core.VarBase`` or ``core.eager.Tensor``; otherwise ``TypeError`` is
    raised. The following entries are skipped:

      * ``ParamBase`` parameters with ``is_distributed`` set, when
        ``is_model_parallel`` is True;
      * ``ParamBase`` parameters with a truthy ``no_sync`` attribute;
      * variables whose type is ``VOCAB``.

    The remaining variables are detached and packed by ``build_groups``
    (128MB group size). Each packed buffer is broadcast over ``comm_group``
    and then split back into the detached variables. Presumably ``detach()``
    shares storage with the original parameters, which is how the broadcast
    values reach the model; confirm.
    """

    def _should_skip(param):
        if isinstance(param, ParamBase):
            # is_distributed param not need to sync when in mp mode
            if is_model_parallel and param.is_distributed:
                return True
            # NOTE(shenliang03): Support situations that do not require
            # synchronization parameters, such as moe's expert parameters
            if getattr(param, "no_sync", False):
                return True
        return param.type == core.VarDesc.VarType.VOCAB

    model_vars = []
    for param in model._obtain_parameters_buffers().values():
        if not isinstance(param, (core.VarBase, core.eager.Tensor)):
            raise TypeError(
                "The data type of '%s' must be Varbase or eager.Tensor" %
                param.name)
        if not _should_skip(param):
            model_vars.append(param.detach())

    if not model_vars:
        return

    # group size is 128M
    coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)

    for packed, _, _ in coalesced_vars:
        paddle.distributed.broadcast(
            packed, src=src_rank, group=comm_group, use_calc_stream=True)

    for packed, targets, shapes in coalesced_vars:
        section_sizes = [np.prod(shape) for shape in shapes]
        paddle.fluid.framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': packed},
            outputs={'Out': targets},
            attrs={'sections': section_sizes,
                   'axis': 0})


@imperative_base.no_grad
@framework.dygraph_only
def sync_eager_params(model, comm_group=None, src_rank=0):
    """Broadcast every parameter/buffer of ``model`` from ``src_rank``.

    Each entry of ``model._obtain_parameters_buffers()`` must be a
    ``core.eager.Tensor``; otherwise ``TypeError`` is raised. Tensors are
    broadcast one at a time through ``comm_group.broadcast``, and
    ``synchronize()`` is called on each returned task before moving on.
    ``comm_group`` must be supplied: with the default ``None``, the first
    broadcast raises ``AttributeError``.
    """
    expected_type = core.eager.Tensor
    for param in model._obtain_parameters_buffers().values():
        if isinstance(param, expected_type):
            comm_group.broadcast(param, src_rank).synchronize()
        else:
            raise TypeError("The data type of '%s' must be '%s'" %
                            (param.name, expected_type))


411
class DataParallel(layers.Layer):
C
chengduo 已提交
412
    """
413
    Run the dygraph module with data parallelism.
C
chengduo 已提交
414

415
    Currently, DataParallel class only supports to run the dynamic graph
416 417 418 419 420 421 422 423 424 425
    with multi-process. 
    
    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)
    
    2. start by ``paddle.distributed.launch`` module, for example:
    
426
        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
427 428

    And the content of `demo.py` is the code of examples.
C
chengduo 已提交
429

430 431
    Args:
        layers(Layer): The module that should be executed by data parallel.
432 433
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism, 
            contains environment configuration related to parallel execution. Default: None.
434
        comm_buffer_size(int, optional): It limits the memory size (MB) of one buffer of
                                          parameters' gradients, which is the input of a communication
                                          call (e.g. NCCLAllReduce). Default: 25.
437 438
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
                                         calling. Making the last communication buffer size small is useful to 
439
                                         improve performance. Default: 1.
440 441 442 443 444 445 446 447 448 449 450
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from
                                                all tensors in the return value of the wrapped model's
                                                forward function. For parameters not involved in loss 
                                                calculation, their gradients will be marked as ready in 
                                                advance to prepare reduce. Please note that all forward 
                                                outputs derived from the wrapped model parameters must 
                                                participate in the calculation of loss and subsequent 
                                                gradient calculations. If not, serious error will occur.
                                                Note that setting the find_unused_parameters to True 
                                                will affect computing performance. Therefore, if all parameters
                                                are sure to participate in the loss calculation and the 
451
                                                autograd graph construction, please set it False. Default: False.
452
            
453 454 455
    Returns:
        Layer: The data paralleled module.

C
chengduo 已提交
456
    Examples:
457

C
chengduo 已提交
458
        .. code-block:: python
459 460
            :name: dp-example

461
            # required: distributed
462 463 464 465 466 467 468 469 470 471 472 473 474 475 476
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
                    
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
477
                # 1. initialize parallel environment
478 479
                dist.init_parallel_env()

480
                # 2. create data parallel layer & optimizer
481 482 483 484 485 486 487
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

488
                # 3. run layer
489 490 491 492 493 494 495 496 497 498 499 500 501 502 503
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
                
                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569


    .. note::
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind, 
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync', 
        and manually implement 'all_reduce' before model optimization. There is an example 
        showing the specific implementation.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            # required: distributed
            import numpy
            import paddle
            import paddle.distributed as dist
            from paddle.autograd import PyLayer
            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.linear = paddle.nn.Linear(2, 2)

                def forward(self, inputs):
                    inputs = cus_tanh.apply(inputs)
                    return self.linear(inputs)

            if __name__ == '__main__':
                dist.init_parallel_env()

                model = SimpleNet()
                model = paddle.DataParallel(model)
                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

                for step in range(10):
                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
                    x = paddle.to_tensor(x_data)
                    x.stop_gradient = False

                    # step 1 : skip gradient synchronization by 'no_sync'
                    with model.no_sync():
                        y_pred = model(x)
                        loss = y_pred.mean()
                        loss.backward()

                    # step 2 : fuse + allreduce manually before optimization
                    fused_allreduce_gradients(list(model.parameters()), None)

                    opt.step()
                    opt.clear_grad()

C
chengduo 已提交
570 571
    """

572 573 574
    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1,
                 find_unused_parameters=False,
                 process_group=None):
        """Wrap ``layers`` for data-parallel execution.

        See the class docstring for the arguments. When the strategy reports
        more than one rank, this method:

          * checks that a parallel context exists;
          * requires a ``process_group`` in eager mode;
          * synchronizes parameters (skipped on XPU builds);
          * builds the gradient reducer.

        Otherwise it only emits a warning and the model runs on a single card.
        """
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        assert _non_static_mode(), \
            "It's not supported to construct DataParallel in static mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        self.grad_need_sync = True
        self.process_group = process_group
        self.var_dtype = (core.eager.Tensor
                          if in_dygraph_mode() else core.VarBase)

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            assert parallel_helper.__parallel_ctx__clz__ is not None, \
                "ParallelContext must be initialized before. You should use init_parallel_env() before " \
                "constructing the DataParallel."

            if self.process_group is None and in_dygraph_mode():
                raise RuntimeError(
                    "Process group should be built for DataParallel in eager mode."
                )

            # sync buffer and params
            # TODO(liuyuhui) Currently not support xpu. xpu is
            # still broadcasting parameters when calling layer
            if not paddle.is_compiled_with_xpu():
                if in_dygraph_mode():
                    sync_eager_params(
                        self._layers, comm_group=self.process_group)
                elif _in_legacy_dygraph():
                    sync_params_buffers(self._layers)

            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): The ``last_comm_buffer_size`` argument (MB,
            # default 1) controls the size of the last group. The overlap
            # cannot work while the last group is allreduced, so keeping that
            # group small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")

    def init_reducer(self):
        """Create the gradient reducer for this layer's trainable parameters.

        Every sublayer's directly-owned parameters
        (``include_sublayers=False``) are visited once each; ``None`` entries
        and duplicates are skipped. A parameter that is not
        ``self.var_dtype`` raises ``TypeError``. The trainable parameters are
        assigned to communication groups using the
        ``[last_comm_buffer_size, comm_buffer_size]`` limits. The groups are
        then handed, in reverse order, to ``core.EagerReducer`` (eager mode)
        or ``core.Reducer`` (legacy dygraph mode), stored as
        ``self._reducer``.
        """
        seen_params = set()
        owner_and_param = []
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                if param is None or param in seen_params:
                    continue
                seen_params.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError("The data type of '%s' must be '%s'" %
                                    (param.name, self.var_dtype))
                if param.trainable:
                    owner_and_param.append((sublayer, param))

        trainable_parameters = [p for _, p in owner_and_param]

        assert len(trainable_parameters) > 0, \
            "This model does not have any parameters to train, and " \
            "does not need to use DataParallel"

        def check_layer_sparse(sublayer):
            # NOTE(shenliang03): Sparsity (SelectedRows) can only be judged
            # from layer attributes, because the gradient type is unknown
            # before backward runs. Layers that support sparse parameters,
            # like "paddle.nn.layer.common.Embedding", must be listed here.
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03): Compatibility check for
            # paddle.fluid.dygraph.Embedding; remove it if that class is removed.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(owner) for owner, _ in owner_and_param
        ]

        if in_dygraph_mode():
            buffer_limits = [self.last_comm_buffer_size, self.comm_buffer_size]
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters, is_sparse_gradient, buffer_limits)
            self._reducer = core.EagerReducer(
                trainable_parameters,
                list(reversed(self.group_indices)), is_sparse_gradient,
                self.process_group, buffer_limits,
                self.find_unused_parameters)
        elif _in_legacy_dygraph():
            buffer_limits = [self.last_comm_buffer_size, self.comm_buffer_size]
            self.group_indices = core.assign_group_by_size(
                trainable_parameters, is_sparse_gradient, buffer_limits)
            self._reducer = core.Reducer(
                trainable_parameters,
                list(reversed(self.group_indices)), is_sparse_gradient,
                parallel_helper.__parallel_ctx__clz__, buffer_limits,
                self.find_unused_parameters)

    def _find_varbase(self, obj):
        """Collect the tensors nested inside ``obj``.

        A tensor is a ``core.eager.Tensor`` in eager mode and a
        ``core.VarBase`` otherwise. Returns ``[obj]`` when ``obj`` is a
        tensor. For lists, tuples and dict values, it recurses and returns an
        ``itertools.chain`` over the results. For anything else, it returns
        an empty list.
        """
        tensor_type = core.eager.Tensor if in_dygraph_mode() else core.VarBase
        if isinstance(obj, tensor_type):
            return [obj]
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return []
        return itertools.chain(
            *[self._find_varbase(child) for child in children])

706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750
    @contextmanager
    def no_sync(self):
        """
        A context manager that suspends gradient synchronization. Inside
        no_sync(), gradients of parameters are only accumulated locally and
        are not synchronized until the first forward-backward pass outside of
        this context. The previous setting is restored on exit, even if an
        exception is raised.

        Examples:
            .. code-block:: python

                # required: distributed
                import paddle
                import paddle.nn as nn
                import paddle.distributed as dist

                class SimpleNet(nn.Layer):
                    def __init__(self):
                        super(SimpleNet, self).__init__()
                        self._linear = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear(x)

                dist.init_parallel_env()
                model = SimpleNet()
                dp_model = paddle.DataParallel(model)

                inputs_1 = paddle.randn([10, 10], 'float32')
                inputs_2 = paddle.ones([10, 10], 'float32')

                with dp_model.no_sync():
                    # gradients will not be synchronized
                    dp_model(inputs_1).backward()

                # synchronization happens here
                dp_model(inputs_2).backward()

        """
        previous_flag = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            self.grad_need_sync = previous_flag

751
    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped layers and, when required, prepare the reducer.

        The reducer is prepared for the upcoming backward pass only when more
        than one rank participates, the current dygraph tracer has
        ``_has_grad`` set, and synchronization has not been switched off
        (e.g. by ``no_sync()``).

        Returns:
            Whatever the wrapped layers return, unchanged.
        """
        outputs = self._layers(*inputs, **kwargs)

        should_sync = (self._strategy.nranks > 1 and
                       framework._dygraph_tracer()._has_grad and
                       self.grad_need_sync)
        if should_sync:
            # Hand the reducer every tensor found anywhere in the outputs.
            found_tensors = list(self._find_varbase(outputs))
            self._reducer.prepare_for_backward(found_tensors)

        return outputs
    @deprecated(
        since="2.0.0",
        reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated since 2.0.0. ``scale_loss`` is now a no-op, kept only
        for backward compatibility.

        Args:
            loss: The loss value.

        Returns:
            ``loss``, returned unchanged.
        """
        return loss
    @deprecated(
        since="2.0.0",
        reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated since 2.0.0. ``apply_collective_grads`` is now an
        empty method, kept only for backward compatibility. It does
        nothing and returns None.
        """
        return None
    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of the wrapped layer and
        its sub-layers, and collect them into a dict.

        This delegates directly to ``state_dict`` of the wrapped layer.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set into this dict. Default: None
            include_sublayers(bool, optional) : If True, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix passed through unchanged to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''
        # Forward every option by keyword so the wrapped layer's parameter
        # order does not matter.
        options = dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)
        return self._layers.state_dict(**options)
    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the
        parameters and buffers will be reset by the tensors in the state_dict.

        This delegates directly to ``set_state_dict`` of the wrapped layer.

        Parameters:
            state_dict(dict) : Dict containing all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If True, use the structured name as key; otherwise, use the parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''
        wrapped_layer = self._layers
        wrapped_layer.set_state_dict(
            state_dict, use_structured_name=use_structured_name)
    # [aliases] Old method names kept for compatibility; both refer to
    # set_state_dict.
    set_dict = load_dict = set_state_dict