parallel.py 33.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16
import six
Y
Yan Xu 已提交
17
import numpy as np
18
import warnings
19
from collections import OrderedDict
S
ShenLiang 已提交
20 21
import itertools
import warnings
22
from contextlib import contextmanager
23

S
ShenLiang 已提交
24
import paddle
25
from paddle import _C_ops, _legacy_C_ops
26 27 28 29 30 31
from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
32
from ..layers import collective
33
from paddle.fluid.dygraph import base as imperative_base
34
from paddle.fluid.framework import ParamBase, _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
35

36
# Names exported by ``from ... parallel import *``.
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
37 38 39 40

# Module-level alias of ``core.ParallelStrategy``; instances carry nranks,
# local_rank, trainer_endpoints and current_endpoint for building a parallel
# context.
ParallelStrategy = core.ParallelStrategy


41
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the global parallel context for dygraph data-parallel
    training. Deprecated: use ``paddle.distributed.init_parallel_env``.

    Args:
        strategy (ParallelStrategy, optional): parallel configuration. If
            None, one is built from the environment variables read by
            ``ParallelEnv``.

    Returns:
        ParallelStrategy|None: the strategy in use, or None when it has
        fewer than 2 ranks (no parallel context is created in that case).

    Raises:
        AssertionError: if not running in dygraph mode, if no expected place
            is set, or if the place is not a CUDAPlace, XPUPlace or NPUPlace.
    '''
    if strategy is None:
        strategy = _build_default_parallel_strategy()
    if strategy.nranks < 2:
        return
    assert framework._non_static_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_ctx = core.NCCLParallelContext(strategy, place)
        elif isinstance(place, core.XPUPlace):
            parallel_ctx = core.BKCLParallelContext(strategy, place)
        elif isinstance(place, core.NPUPlace):
            parallel_ctx = core.HCCLParallelContext(strategy, place)
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # Raise explicitly: asserting a bare (non-empty) string is always
            # true, which would let unsupported places reach
            # _init_parallel_ctx() without any context having been set.
            raise AssertionError(
                "Only support CUDAPlace or XPUPlace or NPUPlace for now.")
        parallel_helper._set_parallel_ctx(parallel_ctx)
        parallel_helper._init_parallel_ctx()
    return strategy
74 75


76 77
class ParallelEnv(object):
    """
    .. note::
        This API is not recommended. If you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", ""))

        # Imperative mode only supports one device per process, so only the
        # first id of the selected-device list is used. Pick the flag that
        # holds that list: the custom XCCL backend wins, otherwise the
        # accelerator this build was compiled for.
        if self._device_type != "":
            selected_flag = 'FLAGS_selected_{}s'.format(self._device_type)
        elif core.is_compiled_with_cuda():
            selected_flag = "FLAGS_selected_gpus"
        elif core.is_compiled_with_xpu():
            selected_flag = "FLAGS_selected_xpus"
        elif core.is_compiled_with_npu():
            selected_flag = "FLAGS_selected_npus"
        elif core.is_compiled_with_mlu():
            selected_flag = "FLAGS_selected_mlus"
        else:
            selected_flag = None

        # Builds without any supported accelerator fall back to the
        # documented default (0) instead of leaving the attribute unset.
        if selected_flag is None:
            self._device_id = 0
        else:
            self._device_id = int(
                os.getenv(selected_flag, "0").split(",")[0])

        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the selected device (e.g. GPU card) for parallel training.

        Its value is the first entry of ``FLAGS_selected_<backend>s`` when
        ``PADDLE_XCCL_BACKEND`` is set; otherwise the first entry of
        ``FLAGS_selected_gpus`` / ``FLAGS_selected_xpus`` /
        ``FLAGS_selected_npus`` / ``FLAGS_selected_mlus``, depending on the
        build. The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def device_type(self):
        """
        The type of custom device for parallel training.

        Its value is equal to the value of the environment variable ``PADDLE_XCCL_BACKEND`` . The default value is "" (empty string).

        """
        return self._device_type

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS``
        split on ",". When the variable is unset the result is ``['']``.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

283

284 285 286 287 288 289
# NOTE: [ Compatible ] This class was originally named `Env`. That name is
# inaccurate and may confuse users, so the class was renamed to `ParallelEnv`;
# the old name is kept as an alias so that older examples keep working.
Env = ParallelEnv


290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from the current environment.

    Returns:
        ParallelStrategy: strategy whose ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` are copied from a
        single ``ParallelEnv`` snapshot.
    """
    strategy = ParallelStrategy()
    # Read the environment once: every ParallelEnv() re-parses all the
    # environment variables and re-runs its validation checks.
    env = ParallelEnv()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """Flatten each group of tensors and concatenate it into one 1-D tensor.

    Args:
        var_groups (dict): maps a group id to the list of tensors in that
            group (e.g. as built by ``build_groups``).

    Returns:
        list: one ``[coalesced_tensor, tensors, shapes]`` entry per group, in
        the iteration order of ``var_groups``; ``shapes`` records the
        original shape of every tensor so the result can be split back.
    """
    from ..layers import nn
    coalesced_grads_and_grad_vars = []
    for grad_vars in var_groups.values():
        flattened_vars = []
        g_var_shapes = []
        for g_var in grad_vars:
            g_var_shapes.append(g_var.shape)
            # int(): np.prod of an empty shape (0-D tensor) is the float 1.0
            # rather than an integer element count.
            flattened_vars.append(
                nn.reshape(x=g_var, shape=[int(np.prod(g_var.shape))]))
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes])
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape tensor ``x`` to ``shape`` in place by tracing a ``reshape2`` op.

    The op's ``Out`` is ``x`` itself. ``reshape2`` also produces an
    ``XShape`` output, which goes into a temporary tensor that is not used
    afterwards.
    """
    scratch_xshape = framework._varbase_creator(dtype=x.dtype)
    op_outputs = {'Out': x, 'XShape': scratch_xshape}
    tracer = framework._dygraph_tracer()
    tracer.trace_op(type="reshape2",
                    inputs={'X': x},
                    outputs=op_outputs,
                    attrs={'shape': shape})
325 326 327 328


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter each coalesced 1-D tensor back into its original tensors.

    For every ``[coalesced, origin_vars, shapes]`` entry (the layout produced
    by ``_coalesce_tensors``), ``coalesced`` is split along axis 0 into
    ``origin_vars`` using the element count of each shape as section size,
    and every output is then reshaped in place to its recorded shape.

    Does nothing unless running in legacy dygraph or eager dygraph mode.
    """
    legacy_mode = _in_legacy_dygraph()
    if not legacy_mode and not in_dygraph_mode():
        return

    for coalesced, origin_vars, shapes in coalesced_grads_and_grad_vars:
        sections = [np.prod(shape) for shape in shapes]
        if legacy_mode:
            framework._dygraph_tracer().trace_op(
                type='split',
                inputs={'X': coalesced},
                outputs={'Out': origin_vars},
                attrs={
                    'sections': sections,
                    'axis': 0
                })
        else:
            _legacy_C_ops.split(coalesced, origin_vars, 'sections', sections,
                                'axis', 0)

        for var, shape in zip(origin_vars, shapes):
            if legacy_mode:
                _reshape_inplace(x=var, shape=shape)
            else:
                var.reshape_(shape=shape)
            assert var.shape == shape
353 354 355


def scale_loss(loss):
    """Divide ``loss`` by the number of trainers.

    When ``ParallelEnv().world_size`` is 1 or less, ``loss`` is returned
    unchanged; otherwise a new tensor ``loss / world_size`` is returned. The
    divisor is a float32 constant with ``stop_gradient`` set.
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    divisor = to_variable(np.array([world_size], dtype="float32"))
    divisor.stop_gradient = True
    return loss / divisor


367 368
@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Pack ``vars`` into size-bounded, same-dtype groups and coalesce them.

    Tensors are visited in order. A tensor joins the current group while the
    bytes accumulated so far are below ``group_size`` and its dtype matches
    the group's dtype; otherwise it starts a new group. A group may
    therefore exceed ``group_size`` by its last tensor.

    Args:
        vars (list): tensors to group.
        group_size (int): soft byte limit per group.

    Returns:
        list: the result of ``_coalesce_tensors`` on the groups, i.e. one
        ``[coalesced_tensor, tensors, shapes]`` entry per group; an empty
        list when ``vars`` is empty.
    """
    if not vars:
        # vars[0] below would raise IndexError on an empty input.
        return []

    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        nbytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += nbytes
        else:
            memory_counter = nbytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(model,
                        comm_group=None,
                        src_rank=0,
                        is_model_parallel=False):
    """Broadcast the parameters and buffers of ``model`` from ``src_rank``.

    Every entry of ``model._obtain_parameters_buffers()`` is detached and
    collected, except:
      * ``ParamBase`` / ``core.eager.Tensor`` entries that are distributed
        when ``is_model_parallel`` is True, or whose ``no_sync`` attribute
        is truthy (e.g. MoE expert parameters);
      * entries of type ``core.VarDesc.VarType.VOCAB``.
    The collected tensors are coalesced into groups of about 128MB
    (``build_groups``), each group is broadcast synchronously over
    ``comm_group``, and the groups are finally split back into the detached
    tensors.

    NOTE(review): results reach the model only if ``detach()`` shares storage
    with the original tensor — confirm.

    Raises:
        TypeError: if an entry is neither ``core.VarBase`` nor
            ``core.eager.Tensor``.
    """
    collected = []
    for _name, tensor in model._obtain_parameters_buffers().items():
        if not isinstance(tensor, (core.VarBase, core.eager.Tensor)):
            raise TypeError(
                "The data type of '%s' must be Varbase or eager.Tensor" %
                tensor.name)

        if isinstance(tensor, (ParamBase, core.eager.Tensor)):
            # Distributed params need no sync in mp mode; params flagged with
            # no_sync (e.g. moe's expert parameters) are skipped as well.
            if (is_model_parallel and tensor.is_distributed) or getattr(
                    tensor, "no_sync", False):
                continue

        if tensor.type == core.VarDesc.VarType.VOCAB:
            continue

        collected.append(tensor.detach())

    if not collected:
        return

    group_bytes = 128 * 1024 * 1024
    groups = build_groups(collected, group_bytes)

    for fused, _, _ in groups:
        paddle.distributed.broadcast(fused,
                                     src=src_rank,
                                     group=comm_group,
                                     sync_op=True)

    for fused, originals, shapes in groups:
        sections = [np.prod(shape) for shape in shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': fused},
            outputs={'Out': originals},
            attrs={
                'sections': sections,
                'axis': 0
            })
435 436


437
class DataParallel(layers.Layer):
C
chengduo 已提交
438
    """
439
    Run the dygraph module with data parallelism.
C
chengduo 已提交
440

441
    Currently, DataParallel class only supports to run the dynamic graph
442 443 444 445 446 447 448 449 450 451
    with multi-process. 
    
    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)
    
    2. start by ``paddle.distributed.launch`` module, for example:
    
452
        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .
453 454

    And the content of `demo.py` is the code of examples.
C
chengduo 已提交
455

456 457
    Args:
        layers(Layer): The module that should be executed by data parallel.
458 459
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism, 
            contains environment configuration related to parallel execution. Default: None.
460
        comm_buffer_size(int, optional):  It limits the memory size(MB) of one buffer  
461 462
                                          parameters' gradient which is the input of communication 
                                          calling(e.g NCCLAllReduce). Default: 25.
463 464
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
                                         calling. Making the last communication buffer size small is useful to 
465
                                         improve performance. Default: 1.
466 467 468 469 470 471 472 473 474 475 476
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from
                                                all tensors in the return value of the wrapped model's 
                                                forward function. For parameters not involved in loss 
                                                calculation, their gradients will be marked as ready in 
                                                advance to prepare reduce. Please note that all forward 
                                                outputs derived from the wrapped model parameters must 
                                                participate in the calculation of loss and subsequent 
                                                gradient calculations. If not, serious error will occur.
                                                Note that setting the find_unused_parameters to True 
                                                will affect computing performance. Therefore, if all parameters
                                                are sure to participate in the loss calculation and the 
477
                                                autograd graph construction, please set it False. Default: False.
478
            
479 480 481
    Returns:
        Layer: The data paralleled module.

C
chengduo 已提交
482
    Examples:
483

C
chengduo 已提交
484
        .. code-block:: python
485 486
            :name: dp-example

487
            # required: distributed
488 489 490 491 492 493 494 495 496 497 498 499 500 501 502
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
                    
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
503
                # 1. initialize parallel environment
504 505
                dist.init_parallel_env()

506
                # 2. create data parallel layer & optimizer
507 508 509 510 511 512 513
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

514
                # 3. run layer
515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
                
                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595


    .. note::
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind, 
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync', 
        and manually implement 'all_reduce' before model optimization. There is an example 
        showing the specific implementation.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            # required: distributed
            import numpy
            import paddle
            import paddle.distributed as dist
            from paddle.autograd import PyLayer
            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.linear = paddle.nn.Linear(2, 2)

                def forward(self, inputs):
                    inputs = cus_tanh.apply(inputs)
                    return self.linear(inputs)

            if __name__ == '__main__':
                dist.init_parallel_env()

                model = SimpleNet()
                model = paddle.DataParallel(model)
                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

                for step in range(10):
                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
                    x = paddle.to_tensor(x_data)
                    x.stop_gradient = False

                    # step 1 : skip gradient synchronization by 'no_sync'
                    with model.no_sync():
                        y_pred = model(x)
                        loss = y_pred.mean()
                        loss.backward()

                    # step 2 : fuse + allreduce manually before optimization
                    fused_allreduce_gradients(list(model.parameters()), None)

                    opt.step()
                    opt.clear_grad()

C
chengduo 已提交
596 597
    """

598 599 600
    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1,
                 find_unused_parameters=False,
                 group=None):
        # Arguments are described in the DataParallel class docstring.
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        assert _non_static_mode(), \
            "It's not supported to construct DataParallel in static mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        self.grad_need_sync = True
        self.group = group
        # Class every parameter must be an instance of (checked in
        # init_reducer): eager Tensor in eager mode, VarBase otherwise.
        self.var_dtype = core.eager.Tensor if in_dygraph_mode(
        ) else core.VarBase

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            assert parallel_helper.__parallel_ctx__clz__ is not None, \
                "ParallelContext must be initialized before. You should use " \
                "init_parallel_env() before constructing the DataParallel."

            if in_dygraph_mode():
                self.group = paddle.distributed.collective._get_default_group(
                ) if self.group is None else self.group

                assert isinstance(self.group, paddle.distributed.collective.Group), \
                    "ProcessGroup must be an instance of Group in DataParallel."

            # sync buffer and params
            # TODO(liuyuhui) Currently not support xpu. xpu is
            # still broadcasting parameters when calling layer
            if not paddle.is_compiled_with_xpu():
                sync_params_buffers(self._layers)

            # Buffer sizes are given in MB; convert them to bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group, Default: 1MB. The role of this small group is:
            # when the last group allreduce, the overlap cannot work. Making the
            # last group small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")
659 660 661 662 663 664 665 666 667

    def init_reducer(self):
        """Create the reducer that synchronizes gradients of trainable params.

        Every unique, non-None parameter of every sublayer is type-checked
        against ``self.var_dtype``. The trainable ones are assigned to
        communication groups bounded by ``self.last_comm_buffer_size`` and
        ``self.comm_buffer_size``. A ``core.EagerReducer`` (eager mode) or a
        ``core.Reducer`` (legacy dygraph) is then built over them and stored
        in ``self._reducer``.

        Raises:
            TypeError: if a parameter is not an instance of ``self.var_dtype``.
            AssertionError: if the model has no trainable parameters.
        """
        seen_params = set()
        owner_and_param = []
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                if param is None or param in seen_params:
                    continue
                seen_params.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError("The data type of '%s' must be '%s'" %
                                    (param.name, self.var_dtype))
                if param.trainable:
                    owner_and_param.append((sublayer, param))

        trainable_parameters = [param for _, param in owner_and_param]

        assert len(trainable_parameters) > 0, \
            "This model does not have any parameters to train, and " \
            "does not need to use DataParallel"

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check will also be removed here.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(owner) for owner, _ in owner_and_param
        ]

        if in_dygraph_mode():
            buffer_sizes = [self.last_comm_buffer_size, self.comm_buffer_size]
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters, is_sparse_gradient, buffer_sizes)
            self._reducer = core.EagerReducer(
                trainable_parameters, list(reversed(self.group_indices)),
                is_sparse_gradient, self.group.process_group, buffer_sizes,
                self.find_unused_parameters)
        elif _in_legacy_dygraph():
            buffer_sizes = [self.last_comm_buffer_size, self.comm_buffer_size]
            self.group_indices = core.assign_group_by_size(
                trainable_parameters, is_sparse_gradient, buffer_sizes)
            self._reducer = core.Reducer(
                trainable_parameters, list(reversed(self.group_indices)),
                is_sparse_gradient, parallel_helper.__parallel_ctx__clz__,
                buffer_sizes, self.find_unused_parameters)
717 718

    def _find_varbase(self, obj):
        """Recursively collect the tensors contained in ``obj``.

        Tensors are recognized as ``core.eager.Tensor`` in eager mode and as
        ``core.VarBase`` otherwise. Lists, tuples and dict values are
        searched recursively (eagerly); any other object yields nothing.

        Returns:
            An iterable of tensors: ``[obj]`` for a tensor, ``[]`` for an
            unsupported object, or an ``itertools.chain`` over the results
            of the container's elements.
        """
        if in_dygraph_mode():
            tensor_type = core.eager.Tensor
        else:
            tensor_type = core.VarBase

        if isinstance(obj, tensor_type):
            return [obj]
        if isinstance(obj, (list, tuple)):
            children = obj
        elif isinstance(obj, dict):
            children = obj.values()
        else:
            return []
        return itertools.chain(
            *[self._find_varbase(child) for child in children])
727

728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
    @contextmanager
    def no_sync(self):
        """
        A context manager to stop gradient synchronization. Within no_sync(),
        gradients of parameters will only be accumulated on the model and not
        synchronized until the first forward-backward pass outside of this
        context.

        Internally this sets ``self.grad_need_sync`` to ``False`` for the
        duration of the ``with`` block. The previous value is restored on
        exit, even if the block raises.

        Examples:
            .. code-block:: python

                # required: distributed
                import paddle
                import paddle.nn as nn
                import paddle.distributed as dist

                class SimpleNet(nn.Layer):
                    def __init__(self):
                        super(SimpleNet, self).__init__()
                        self._linear = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear(x)

                dist.init_parallel_env()
                model = SimpleNet()
                dp_model = paddle.DataParallel(model)

                inputs_1 = paddle.randn([10, 10], 'float32')
                inputs_2 = paddle.ones([10, 10], 'float32')

                with dp_model.no_sync():
                    # gradients will not be synchronized
                    dp_model(inputs_1).backward()

                # synchronization happens here
                dp_model(inputs_2).backward()

        """
        # Remember the caller's setting so nested/no-op usage restores it exactly.
        tmp_grad_need_sync = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            self.grad_need_sync = tmp_grad_need_sync

773
    def forward(self, *inputs, **kwargs):
        """Run the wrapped layers, registering outputs with the reducer if needed.

        The reducer is prepared for backward only when there is more than one
        rank, the dygraph tracer is recording gradients, and gradient sync has
        not been disabled (e.g. via ``no_sync``).
        """
        result = self._layers(*inputs, **kwargs)
        # Keep the short-circuit order: the tracer is only queried when
        # running on multiple ranks.
        should_prepare = (self._strategy.nranks > 1
                          and framework._dygraph_tracer()._has_grad
                          and self.grad_need_sync)
        if should_prepare:
            output_tensors = list(self._find_varbase(result))
            self._reducer.prepare_for_backward(output_tensors)
        return result
Y
Yan Xu 已提交
780

781 782
    @deprecated(since="2.0.0",
                reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated no-op, kept only for backward compatibility.

        Args:
            loss: Any value; it is handed back untouched.

        Returns:
            The very same ``loss`` object that was passed in.
        """
        unchanged_loss = loss
        return unchanged_loss

790 791
    @deprecated(since="2.0.0",
                reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated no-op, kept only for backward compatibility.

        Calling this method is no longer necessary; it does nothing and
        returns ``None``.
        """
        return None
798 799 800 801 802 803

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of the wrapped layer and its
        sub-layers, and set them into a dict.

        This call is delegated to the wrapped layer's ``state_dict``, so the
        returned keys are the same as those of the unwrapped model.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set into this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Passed through unchanged to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

834
    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensors in the state_dict.

        This call is delegated to the wrapped layer's ``set_state_dict``, so
        the expected keys are the same as those of the unwrapped model.

        Parameters:
            state_dict(dict) : Dict containing all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use the structured name as key; otherwise, use the parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(state_dict,
                                    use_structured_name=use_structured_name)
867 868 869 870

    # [aliases] Legacy method names, bound to the same function as
    # set_state_dict for backward compatibility.
    set_dict = load_dict = set_state_dict