parallel.py 43.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Q
qizhaoaoe 已提交
15
import itertools
16
import os
17
import sys
18
import time
19
import warnings
20
from collections import OrderedDict, namedtuple
Q
qizhaoaoe 已提交
21
from contextlib import contextmanager
22
from multiprocessing import Manager  # noqa: F401
23 24
from multiprocessing import Process  # noqa: F401

Q
qizhaoaoe 已提交
25 26
import numpy as np

27
import paddle
Q
qizhaoaoe 已提交
28
from paddle import _legacy_C_ops, framework
29 30 31 32 33 34 35 36 37 38 39 40
from paddle.distributed.collective import (
    Group,
    _default_group_name,
    _get_group_map_by_name,
    _new_process_group_impl,
    _set_default_backend,
    _set_default_store,
    _set_group_map,
    _set_group_map_backend,
    _set_group_map_by_name,
    _valid_backend_list,
)
41 42 43 44 45
from paddle.distributed.communication.group import (
    _add_new_group,
    _get_global_group,
    is_initialized,
)
46 47 48 49
from paddle.distributed.fleet.base.private_helper_function import (  # noqa: F401
    wait_server_ready,
)
from paddle.distributed.fleet.launch_utils import check_backend
50

51
# (TODO: GhostScreaming) It will be removed later.
W
wanghuancoder 已提交
52
from paddle.framework import _set_expected_place
Q
qizhaoaoe 已提交
53
from paddle.framework import base as imperative_base
54
from paddle.framework import core, in_dynamic_mode
55
from paddle.nn.layer import layers
Q
qizhaoaoe 已提交
56
from paddle.utils import deprecated
57

Q
qizhaoaoe 已提交
58
from . import parallel_helper
59

60
__all__ = []

# Module-level alias of ``core.ParallelStrategy``; instances of it are
# created and populated from ParallelEnv in _build_default_parallel_strategy.
ParallelStrategy = core.ParallelStrategy


def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from ``ParallelEnv``.

    The fields ``nranks``, ``local_rank``, ``trainer_endpoints`` and
    ``current_endpoint`` are all copied from one ``ParallelEnv`` instance.

    Returns:
        ParallelStrategy: The populated strategy.
    """
    # Construct the env only once: every ``ParallelEnv()`` call builds a fresh
    # object, so a single snapshot avoids redundant work and keeps the copied
    # fields mutually consistent.
    env = paddle.distributed.ParallelEnv()
    strategy = ParallelStrategy()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """Flatten and concatenate each group of tensors into one tensor.

    Args:
        var_groups (dict): Maps a group id to the list of tensors in that
            group. Only the values are used; groups are processed in the
            dict's iteration order.

    Returns:
        list: One ``[coalesced_tensor, tensors, shapes]`` entry per group.
        ``coalesced_tensor`` is the concatenation of the flattened tensors,
        ``tensors`` is the group's original list and ``shapes`` holds each
        tensor's original shape, so the entry can later be split back
        (see ``_split_tensors``).
    """
    coalesced_grads_and_grad_vars = []
    for grad_vars in var_groups.values():
        g_var_shapes = [g_var.shape for g_var in grad_vars]
        # Use a Python int for the flattened length: for a 0-d tensor
        # (shape []) np.prod returns the float 1.0, which is not a valid size.
        flattened_vars = [
            paddle.reshape(x=g_var, shape=[int(np.prod(g_var.shape))])
            for g_var in grad_vars
        ]
        coalesced_grad = paddle.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes]
        )
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape ``x`` to ``shape`` in place by tracing a ``reshape2`` op.

    The op's ``Out`` is bound to ``x`` itself, so ``x`` is updated in place.
    Its auxiliary ``XShape`` output goes to a scratch tensor of ``x``'s dtype
    that is discarded afterwards.
    """
    xshape_scratch = framework._create_tensor(dtype=x.dtype)
    tracer = framework._dygraph_tracer()
    op_inputs = {'X': x}
    op_outputs = {'Out': x, 'XShape': xshape_scratch}
    op_attrs = {'shape': shape}
    tracer.trace_op(
        type="reshape2",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs=op_attrs,
    )


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter coalesced tensors back into their original tensors.

    Inverse of ``_coalesce_tensors``. For each ``[coalesced, origin_vars,
    shapes]`` entry, ``coalesced`` is split along axis 0 into
    ``origin_vars``, and each of those is then reshaped in place to its
    recorded shape.

    Args:
        coalesced_grads_and_grad_vars (list): Entries in the format produced
            by ``_coalesce_tensors``.
    """
    if not in_dynamic_mode():
        return
    for (
        coalesced_grad,
        origin_grad_vars,
        grad_shapes,
    ) in coalesced_grads_and_grad_vars:
        # Section sizes must be Python ints: for a 0-d shape ([]) np.prod
        # returns the float 1.0, which may be rejected as an int-list attr.
        grad_var_len = [int(np.prod(g_shape)) for g_shape in grad_shapes]
        _legacy_C_ops.split(
            coalesced_grad,
            origin_grad_vars,
            'sections',
            grad_var_len,
            'axis',
            0,
        )
        for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
            # Restore each piece's original shape (the pieces come out of a
            # split along axis 0 of the coalesced buffer).
            g_var.reshape_(shape=g_shape)
            assert g_var.shape == g_shape


@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Bucket tensors into dtype-homogeneous groups and coalesce each group.

    Tensors are scanned in order. A new group starts whenever the running
    byte count of the current group has reached ``group_size`` or the dtype
    changes. Each group therefore holds a single dtype and can exceed
    ``group_size`` by at most one tensor.

    Args:
        vars (list[Tensor]): Tensors to group, in order.
        group_size (int): Byte threshold at which a new group is started.

    Returns:
        list: The ``_coalesce_tensors`` result for the groups: one
        ``[coalesced_tensor, tensors, shapes]`` entry per group. Empty when
        ``vars`` is empty.
    """
    # Guard: ``vars[0]`` below would raise IndexError on an empty input.
    if len(vars) == 0:
        return []

    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        nbytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += nbytes
        else:
            # Current group is full or the dtype changed: start a new group.
            memory_counter = nbytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(
    model, comm_group=None, src_rank=0, is_model_parallel=False
):
    """Broadcast the parameters and buffers of ``model`` from ``src_rank``.

    The tensors returned by ``model._obtain_parameters_buffers()`` are
    collected, except for the skipped cases below. They are packed into
    dtype-homogeneous buckets of about 128 MB via ``build_groups`` and
    broadcast from ``src_rank`` over ``comm_group``. Each bucket is then split
    back into the collected tensors.

    Skipped tensors:
      * tensors with a truthy ``is_distributed``, when ``is_model_parallel``;
      * tensors with a truthy ``no_sync`` attribute (e.g. MoE experts);
      * tensors whose ``type`` is ``core.VarDesc.VarType.VOCAB``.

    Args:
        model (Layer): The model whose parameters/buffers are synchronized.
        comm_group (Group, optional): Group to broadcast over. Default: None.
        src_rank (int, optional): Rank to broadcast from. Default: 0.
        is_model_parallel (bool, optional): Whether to skip tensors marked
            ``is_distributed``. Default: False.

    Raises:
        TypeError: If a parameter or buffer is not a ``core.eager.Tensor``.
    """
    model_vars = []
    for param in model._obtain_parameters_buffers().values():
        if not isinstance(param, core.eager.Tensor):
            raise TypeError(
                "The data type of '%s' must be core.eager.Tensor" % param.name
            )

        # In model-parallel mode, tensors flagged ``is_distributed`` are not
        # synchronized (presumably each rank owns its own shard - confirm).
        if is_model_parallel and getattr(param, "is_distributed", False):
            continue

        # NOTE(shenliang03): Support situations that do not require synchronization parameters,
        # such as moe's expert parameters
        if getattr(param, "no_sync", False):
            continue

        if param.type == core.VarDesc.VarType.VOCAB:
            continue

        # NOTE(review): the split below writes into these detached tensors;
        # this presumably updates the model because detach() shares storage
        # with the parameter - confirm.
        model_vars.append(param.detach())
    if len(model_vars) == 0:
        return

    # group size is 128M
    coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)

    for coalesced_var, _, _ in coalesced_vars:
        paddle.distributed.broadcast(
            coalesced_var, src=src_rank, group=comm_group, sync_op=True
        )

    for coalesced_var, origin_vars, var_shapes in coalesced_vars:
        # Section sizes must be Python ints (np.prod of a 0-d shape is 1.0).
        var_len = [int(np.prod(v_shape)) for v_shape in var_shapes]
        # Use the already-imported ``framework`` module (as the rest of this
        # file does) instead of the legacy ``paddle.fluid`` path.
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_var},
            outputs={'Out': origin_vars},
            attrs={'sections': var_len, 'axis': 0},
        )


class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, DataParallel class only supports to run the dynamic graph
    with multi-process.

    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)

    2. start by ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .

    And the content of `demo.py` is the code of examples.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            contains environment configuration related to parallel execution. Default: None.
        comm_buffer_size(int, optional): It limits the memory size(MB) of one buffer
            of parameters' gradients which is the input of communication
            calling(e.g NCCLAllReduce). Default: 25.
        last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
            calling. Making the last communication buffer size small is useful to
            improve performance. Default: 1.
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from
            all tensors in the return value of the wrapped model's
            forward function. For parameters not involved in loss
            calculation, their gradients will be marked as ready in
            advance to prepare reduce. Please note that all forward
            outputs derived from the wrapped model parameters must
            participate in the calculation of loss and subsequent
            gradient calculations. If not, serious error will occur.
            Note that setting the find_unused_parameters to True
            will affect computing performance. Therefore, if all parameters
            are sure to participate in the loss calculation and the
            autograd graph construction, please set it False. Default: False.
        group(Group, optional): The process group used for gradient synchronization.
            Default: None, which means the default group is used.

    Returns:
        Layer: The data paralleled module.

    Examples:

        .. code-block:: python
            :name: dp-example

            # required: distributed
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. initialize parallel environment
                dist.init_parallel_env()

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()


    .. note::
        ``PyLayer`` is not supported in DataParallel. To solve problems of this kind,
        it's recommended to skip gradient synchronization among multiple cards by 'no_sync',
        and manually implement 'all_reduce' before model optimization. There is an example
        showing specific implementation processing.

    Examples:

        .. code-block:: python
            :name: dp-pylayer-example

            # required: distributed
            import numpy
            import paddle
            import paddle.distributed as dist
            from paddle.autograd import PyLayer
            from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

            class cus_tanh(PyLayer):
                @staticmethod
                def forward(ctx, x):
                    y = paddle.tanh(x)
                    ctx.save_for_backward(y)
                    return y

                @staticmethod
                def backward(ctx, dy):
                    y, = ctx.saved_tensor()
                    grad = dy * (1 - paddle.square(y))
                    return grad

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super().__init__()
                    self.linear = paddle.nn.Linear(2, 2)

                def forward(self, inputs):
                    inputs = cus_tanh.apply(inputs)
                    return self.linear(inputs)

            if __name__ == '__main__':
                dist.init_parallel_env()

                model = SimpleNet()
                model = paddle.DataParallel(model)
                opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

                for step in range(10):
                    x_data = numpy.random.randn(2, 2).astype(numpy.float32)
                    x = paddle.to_tensor(x_data)
                    x.stop_gradient = False

                    # step 1 : skip gradient synchronization by 'no_sync'
                    with model.no_sync():
                        y_pred = model(x)
                        loss = y_pred.mean()
                        loss.backward()

                    # step 2 : fuse + allreduce manually before optimization
                    fused_allreduce_gradients(list(model.parameters()), None)

                    opt.step()
                    opt.clear_grad()

    """

    def __init__(
        self,
        layers,
        strategy=None,
        comm_buffer_size=25,
        last_comm_buffer_size=1,
        find_unused_parameters=False,
        group=None,
    ):
        super().__init__(layers.full_name() + "_data_parallel")

        assert (
            in_dynamic_mode()
        ), "It's not supported to construct DataParallel in static graph mode."

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters
        # Toggled off by ``no_sync()`` to skip reducer preparation in forward.
        self.grad_need_sync = True
        self.group = group
        self.var_dtype = core.eager.Tensor

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            assert parallel_helper.__parallel_ctx__clz__ is not None, (
                "ParallelContext must be initialized before. You should use init_parallel_env() before "
                "constructing the DataParallel."
            )

            if in_dynamic_mode():
                self.group = (
                    paddle.distributed.collective._get_default_group()
                    if self.group is None
                    else self.group
                )

                assert isinstance(
                    self.group, paddle.distributed.collective.Group
                ), "ProcessGroup must be an instance of Group in DataParallel."

            # sync buffer and params
            # NOTE(review): comm_group/src_rank are left at their defaults
            # (None, 0) even when a custom ``group`` was passed - confirm this
            # is intended.
            sync_params_buffers(self._layers)

            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group, Default: 1MB. The role of this small group is:
            # when the last group allreduce, the overlap cannot work. Making the
            # the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(
                last_comm_buffer_size * 1024 * 1024
            )
            self.init_reducer()
        else:
            warnings.warn(
                "The program will return to single-card operation. "
                "Please check 1, whether you use spawn or fleetrun "
                "to start the program. 2, Whether it is a multi-card "
                "program. 3, Is the current environment multi-card."
            )

    def init_reducer(self):
        """Create the gradient reducer for the wrapped layers.

        Collects each unique, trainable parameter of all sublayers and drops
        those with a truthy ``no_sync`` attribute. The rest are assigned to
        communication groups by size and handed to ``core.EagerReducer``.

        Raises:
            TypeError: If a parameter is not of type ``self.var_dtype``.
            AssertionError: If no parameter is left to synchronize.
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                # Shared parameters appear under several sublayers; keep one.
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, self.var_dtype):
                    raise TypeError(
                        f"The data type of '{param.name}' must be '{self.var_dtype}'"
                    )
                if param.trainable:
                    layers_param.append((sublayer, param))

        # Filter the (sublayer, param) pairs rather than only the params, so
        # that ``is_sparse_gradient`` below stays index-aligned with
        # ``trainable_parameters``; filtering only the params made the two
        # lists diverge whenever a parameter had ``no_sync`` set.
        sync_layers_param = [
            (sublayer, param)
            for sublayer, param in layers_param
            if not getattr(param, "no_sync", False)
        ]
        trainable_parameters = [param for _, param in sync_layers_param]

        assert len(trainable_parameters) > 0, (
            "This model does not have any parameters to train, and "
            "does not need to use DataParallel"
        )

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in sync_layers_param
        ]

        if in_dynamic_mode():
            self.group_indices = core.eager_assign_group_by_size(
                trainable_parameters,
                is_sparse_gradient,
                [self.last_comm_buffer_size, self.comm_buffer_size],
            )

            self._reducer = core.EagerReducer(
                trainable_parameters,
                list(reversed(self.group_indices)),
                is_sparse_gradient,
                self.group.process_group,
                [self.last_comm_buffer_size, self.comm_buffer_size],
                self.find_unused_parameters,
            )

    def _find_tensor(self, obj):
        """Return an iterable of every ``core.eager.Tensor`` inside ``obj``.

        Lists, tuples and dict values are searched recursively; any other
        object contributes nothing.
        """
        var_type = core.eager.Tensor
        if isinstance(obj, var_type):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return itertools.chain(*map(self._find_tensor, obj))
        if isinstance(obj, dict):
            return itertools.chain(*map(self._find_tensor, obj.values()))
        return []

    @contextmanager
    def no_sync(self):
        """
        A context manager to stop gradient synchronization. Within no_sync(),
        gradients of parameters will only be accumulated on model and not
        synchronized until the first forward-backward out of this context.

        Examples:
            .. code-block:: python

                # required: distributed
                import paddle
                import paddle.nn as nn
                import paddle.distributed as dist

                class SimpleNet(nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = nn.Linear(10, 1)

                    def forward(self, x):
                        return self._linear(x)

                dist.init_parallel_env()
                model = SimpleNet()
                dp_model = paddle.DataParallel(model)

                inputs_1 = paddle.randn([10, 10], 'float32')
                inputs_2 = paddle.ones([10, 10], 'float32')

                with dp_model.no_sync():
                    # gradients will not be synchronized
                    dp_model(inputs_1).backward()

                # synchronization happens here
                dp_model(inputs_2).backward()

        """
        tmp_grad_need_sync = self.grad_need_sync
        self.grad_need_sync = False
        try:
            yield
        finally:
            # Restore the previous value so nested no_sync() blocks behave.
            self.grad_need_sync = tmp_grad_need_sync

    def forward(self, *inputs, **kwargs):
        """Run the wrapped layers.

        When running with more than one rank, with gradients enabled, and
        outside ``no_sync()``, the tensors in the outputs are passed to the
        reducer's ``prepare_for_backward``.
        """
        outputs = self._layers(*inputs, **kwargs)
        if (
            self._strategy.nranks > 1
            and framework._dygraph_tracer()._has_grad
            and self.grad_need_sync
        ):
            self._reducer.prepare_for_backward(list(self._find_tensor(outputs)))
        return outputs

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore."
    )
    def scale_loss(self, loss):
        """
        Deprecated method, now ``scale_loss`` is an empty method,
        keep this method just for compatibility.
        """
        return loss

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore."
    )
    def apply_collective_grads(self):
        """
        Deprecated method, now ``apply_collective_grads`` is an empty method,
        keep this method just for compatibility.
        """
        return

    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
    ):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Forwarded unchanged to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
        )

    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict, use_structured_name=use_structured_name
        )

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict


638
# NOTE(chenweihang): Maintain a global parallel env to avoid
639 640 641 642
# initializing ParallelEnv every time and improve performance
_global_parallel_env = None


class ParallelEnv:
    """
    .. note::
        This API is not recommended, if you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 0:
            # rank: 0
            # world_size: 2
            # print result in process 1:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._device_type = str(os.getenv("PADDLE_XCCL_BACKEND", ""))
        # Documented default. Without it, a build with neither a custom device
        # nor CUDA/XPU left ``_device_id`` unset and ``device_id`` raised
        # AttributeError.
        self._device_id = 0

        # imperative only support one gpu or xpu
        if self._device_type != "":
            FLAGS_selected_custom_devices = (
                f'FLAGS_selected_{self._device_type}s'
            )
            selected_custom_devices = os.getenv(
                FLAGS_selected_custom_devices, "0"
            ).split(",")
            self._device_id = int(selected_custom_devices[0])
        else:
            if core.is_compiled_with_cuda():
                selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
                self._device_id = int(selected_gpus[0])
            elif core.is_compiled_with_xpu():
                selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
                self._device_id = int(selected_xpus[0])

        self._trainer_endpoints = os.getenv(
            "PADDLE_TRAINER_ENDPOINTS", ""
        ).split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert (
            self._nrings > 0
        ), "nccl_nrings must be an integer greater than 0."
        assert (
            self._nrings < 9
        ), "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the selected device for parallel training.

        Its value is the first entry of ``FLAGS_selected_<device_type>s`` when a custom
        device type is set, otherwise of ``FLAGS_selected_gpus`` (CUDA builds) or
        ``FLAGS_selected_xpus`` (XPU builds). The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id are %d" % env.device_id)
            # The device id are 1
        """
        return self._device_id

    @property
    def device_type(self):
        """
        The type of custom device for parallel training.

        Its value is equal to the value of the environment variable ``PADDLE_XCCL_BACKEND`` . The default value is "" (empty string).

        """
        return self._device_type

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint are %s" % env.current_endpoint)
            # The current endpoint are 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS`` split on ",".
        When the variable is not set, the value is ``[""]``.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # the number of ring is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id


850 851 852 853 854 855
def _get_global_parallel_env():
    """Return the module-wide ``ParallelEnv``, creating it on first use.

    The instance is cached in the module global ``_global_parallel_env``;
    later calls return that same object until something rebinds it.
    """
    global _global_parallel_env
    if _global_parallel_env is not None:
        return _global_parallel_env
    _global_parallel_env = ParallelEnv()
    return _global_parallel_env

856

857
def _start_kv_server(port, http_server_d, size):
    """
    Run a ``KVServer`` on ``port`` and block until it is allowed to stop.

    Args:
        port (int|str): Port to listen on; converted with ``int``.
        http_server_d (dict-like): Shared status mapping. While its
            ``"running"`` entry is truthy the server is kept alive.
        size (dict): Passed through to ``KVServer`` as ``size``.

    The wait loop ends only when ``http_server_d["running"]`` is falsy (or
    absent) AND the server reports ``should_stop()``. The condition is
    re-checked every 3 seconds, after which the server is stopped.
    """
    from paddle.distributed.fleet.utils.http_server import KVServer

    server = KVServer(int(port), size=size)
    server.start()
    poll_interval = 3
    while True:
        keep_alive = http_server_d.get("running", False)
        # `should_stop` is only consulted once the owner cleared "running".
        if not keep_alive and server.should_stop():
            break
        time.sleep(poll_interval)
    server.stop()


X
xiongkun 已提交
868 869
def _is_cpuonly(backend):
    """Decide whether ``backend`` results in CPU-only training.

    ``backend`` is validated with ``check_backend`` first.

    Returns:
        bool: False for ``'xccl'``. False for ``'auto'``, ``'nccl'``,
        ``'bkcl'`` or ``'heter'`` when Paddle is compiled with CUDA or XPU.
        True otherwise.
    """
    check_backend(backend)
    if backend == 'xccl':
        return False
    if backend not in ['auto', 'nccl', 'bkcl', 'heter']:
        return True
    # Accelerator-capable backend: CPU-only only if no CUDA/XPU build exists.
    return not (core.is_compiled_with_cuda() or core.is_compiled_with_xpu())


K
kuizhiqing 已提交
880 881 882
def _check_var_exists(var_name):
    var = os.environ.get(var_name, None)
    if var is None:
883 884 885 886
        raise ValueError(
            "paddle.distributed initialize error, "
            "environment variable %s is needed, but not set." % var_name
        )
K
kuizhiqing 已提交
887 888


889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913
def _get_modified_flags():
    """Collect the global flags whose current value differs from the default.

    Returns:
        list: ``FLAGS(name, current_value, default_value)`` namedtuples, one
        for each key of ``core.globals()`` whose value is not equal to its
        default. The list follows the iteration order of ``keys()``.
    """
    flag_record = namedtuple(
        'FLAGS', ['name', 'current_value', 'default_value']
    )
    registry = core.globals()
    changed = []
    for flag_name in registry.keys():
        current = registry.get(flag_name)
        default = registry.get_default(flag_name)
        if current == default:
            continue
        changed.append(flag_record(flag_name, current, default))
    return changed


def _print_modified_flags(modified_flags):
    if len(modified_flags) > 0:
        sys.stderr.write(
            "======================= Modified FLAGS detected =======================\n"
        )
        for flag in modified_flags:
            sys.stderr.write(str(flag))
            sys.stderr.write("\n")
        sys.stderr.write(
            "=======================================================================\n"
        )


X
xiongkun 已提交
914
def init_parallel_env():
    """

    Initialize parallel training environment in dynamic graph mode.

    Note:
        Now initialize both `NCCL` and `GLOO` contexts for communication.

    This function takes no arguments; it is configured through environment
    variables. The communication backend is read from
    ``PADDLE_DISTRI_BACKEND`` (default ``'auto'``), e.g. 'gloo' (for cpu),
    'nccl' (for cuda), 'bkcl' (for xpu), 'heter', 'xccl' (for custom devices)
    or 'auto' (auto detect). The auto detection prefers 'nccl' / 'bkcl' over
    'gloo'. ``PADDLE_TRAINER_ID``, ``PADDLE_CURRENT_ENDPOINT`` and
    ``PADDLE_TRAINERS_NUM`` must be set (a ``ValueError`` is raised
    otherwise).

    Returns:
        Group or None: When the backend is in the process-group backend list
        and dynamic mode is on, the default communication ``Group`` (an
        already-registered default group is returned as is). Otherwise, and
        when the world size is smaller than 2, None.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. initialize parallel environment
                dist.init_parallel_env()

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                dist.spawn(train)

    """

    modified_flags = _get_modified_flags()
    _print_modified_flags(modified_flags)

    # 0. get env & check world size
    global _global_parallel_env
    # when call init_parallel_env, need update `_global_parallel_env`
    _global_parallel_env = ParallelEnv()
    parallel_env = _global_parallel_env
    # if not parallel, `init_parallel_env` do nothing
    if parallel_env.world_size < 2:
        warnings.warn(
            "Currently not a parallel execution environment, `paddle.distributed.init_parallel_env` will not do anything."
        )
        return
    # NOTE(xiongkun): support cpu gloo only, add this environment variable to
    #                 enable cpu only gloo parallel training)
    backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
    is_cpu_only = _is_cpuonly(backend)
    # 1. gpu xpu check, must be gpu or xpu,
    if not (
        is_cpu_only
        or core.is_compiled_with_cuda()
        or core.is_compiled_with_xpu()
        or backend == "xccl"
    ):
        raise NotImplementedError(
            "If you want to use CPU-only version, please use 'gloo' as backend"
        )

    if backend == "xccl":
        FLAGS_selected_custom_devices = 'FLAGS_selected_{}s'.format(
            parallel_env.device_type
        )
        _check_var_exists(FLAGS_selected_custom_devices)
    else:
        if not is_cpu_only and core.is_compiled_with_cuda():
            _check_var_exists("FLAGS_selected_gpus")
            backend = "nccl" if backend == "auto" else backend
        elif not is_cpu_only and core.is_compiled_with_xpu():
            _check_var_exists('FLAGS_selected_xpus')
            backend = "bkcl" if backend == "auto" else backend

    _check_var_exists("PADDLE_TRAINER_ID")
    _check_var_exists("PADDLE_CURRENT_ENDPOINT")
    _check_var_exists("PADDLE_TRAINERS_NUM")

    # NOTE(chenweihang): [ why config global place here? ]
    # the dygraph mode will be set to default mode,
    # users will not call `dygraph.guard` or `enable_dygraph`
    # directly, if they want to switch default place,
    # they need to call a function to change default place,
    # here just set correctly place to users
    if backend == "xccl":
        place = core.CustomPlace(
            parallel_env.device_type, parallel_env.device_id
        )
    elif is_cpu_only:
        place = core.CPUPlace()
    elif core.is_compiled_with_cuda():
        place = core.CUDAPlace(parallel_env.device_id)
    elif core.is_compiled_with_xpu():
        place = core.XPUPlace(parallel_env.device_id)
    _set_expected_place(place)

    group = None

    if backend in _valid_backend_list and in_dynamic_mode():
        # Re-initialization returns the already-registered default group.
        if _default_group_name in _get_group_map_by_name():
            return _get_group_map_by_name()[_default_group_name]
        _set_default_backend(backend)
        rank = int(os.getenv("PADDLE_TRAINER_ID"))
        world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))
        assert rank >= 0 and world_size > rank and world_size > 1, (
            "rank must be non-negative and world_size must be the "
            "maximum rank plus one. Moreover, at least two processes are "
            "required to create a process group."
        )
        master_addr = os.getenv("MASTER_ADDR", None)
        master_port = os.getenv("MASTER_PORT", None)
        endpoints = (
            ":".join([master_addr, master_port])
            if master_addr and master_port
            else None
        )
        if endpoints is None:
            endpoints = os.getenv("PADDLE_MASTER", None)
        if endpoints is None:
            # Default to "" so an unset variable reaches the explanatory
            # assertion below instead of failing with AttributeError on None.
            endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')[0]
        assert endpoints, (
            "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
            "must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
            "and 'export MASTER_PORT=54612'. Or you can start your training "
            "with paddle.distributed.run module."
        )
        master_addr, master_port = endpoints.split(":")
        master_port = int(master_port)
        is_master = rank == 0
        stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
        default_store = core.TCPStore(
            master_addr,
            master_port,
            is_master,
            world_size,
            timeout=stop_check_timeout,
        )
        _set_default_store(default_store)
        pg = _new_process_group_impl(
            backend,
            default_store,
            rank,
            world_size,
            _default_group_name,
            pg_options=None,
        )
        ranks = list(range(world_size))
        group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
        _set_group_map_by_name(_default_group_name, group)
        _set_group_map(0, group)
        _set_group_map_backend(group, backend)
        _add_new_group(group)
        parallel_helper._set_parallel_ctx(True)

        paddle.distributed.barrier(group=group)
        return group

    # Distinct node IPs; used to size the '_worker' scope for 'heter'.
    node_ips = {i.split(":")[0] for i in parallel_env.trainer_endpoints}
    # 3: init gloo context (step 1: http server start)
    init_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
    if is_cpu_only or init_gloo or backend == "heter":
        ep_rank_0 = parallel_env.trainer_endpoints[0].split(":")
        manager = Manager()
        # global dict to store status
        http_server_d = manager.dict()
        http_server_d["running"] = False
        if parallel_env.rank == 0:
            # The scope for worker used by http server is '_worker'
            size = {'_worker': parallel_env.world_size}
            if backend == "heter":
                size = {'_worker': len(node_ips)}
            http_server = Process(
                target=_start_kv_server,
                args=(int(ep_rank_0[1]), http_server_d, size),
            )
            http_server.daemon = True
            http_server_d["running"] = True
            http_server.start()

    # 4. init NCCL ParallelStrategy
    strategy = ParallelStrategy()
    if parallel_helper._is_parallel_ctx_initialized():
        warnings.warn("The parallel environment has been initialized.")
    strategy.nranks = parallel_env.world_size
    strategy.local_rank = parallel_env.rank
    strategy.trainer_endpoints = parallel_env.trainer_endpoints
    strategy.current_endpoint = parallel_env.current_endpoint
    strategy.nrings = parallel_env.nrings

    # init nccl or bkcl or heter context
    if is_cpu_only:
        parallel_helper._set_parallel_ctx(
            core.GLOOParallelContext(strategy, place)
        )
    elif backend == "heter":
        parallel_helper._set_parallel_ctx(
            core.HeterParallelContext(strategy, parallel_env.device_id)
        )
    elif core.is_compiled_with_cuda():
        parallel_helper._set_parallel_ctx(
            core.NCCLParallelContext(strategy, place)
        )
    elif core.is_compiled_with_xpu():
        parallel_helper._set_parallel_ctx(
            core.BKCLParallelContext(strategy, place)
        )

    if backend != "heter":
        other_endpoints = strategy.trainer_endpoints[:]
        other_endpoints.remove(strategy.current_endpoint)
        if not is_cpu_only and strategy.local_rank == 0:
            wait_server_ready(other_endpoints)

    parallel_helper._init_parallel_ctx()

    # 5: init gloo context (step 2: gloo init)
    # dividing init_gloo into two part because nccl and gloo
    # are separately looking for free ports which sometimes
    # leads to port-conflict.
    if (is_cpu_only or backend == "heter") and parallel_env.rank == 0:
        # compare to init_gloo, we don't need to
        # init gloo, because we do this in _init_parallel_ctx;
        http_server_d["running"] = False
        http_server.join()

    elif init_gloo:
        wait_server_ready([parallel_env.trainer_endpoints[0]])
        gloo_strategy = core.GlooParallelStrategy()
        gloo_strategy.rank = parallel_env.rank
        gloo_strategy.rank_num = parallel_env.world_size
        gloo_strategy.ip_address = ep_rank_0[0]
        gloo_strategy.ip_port = int(ep_rank_0[1])
        default_init_timeout_seconds = 3600
        default_run_timeout_seconds = 9999999
        gloo_strategy.init_seconds = default_init_timeout_seconds
        gloo_strategy.run_seconds = default_run_timeout_seconds
        gloo = core.GlooParallelContext(gloo_strategy)
        gloo.init()
        if parallel_env.rank == 0:
            http_server_d["running"] = False
            http_server.join()
    return group
1186

1187

L
LiYuRio 已提交
1188
def get_rank(group=None):
    """

    Returns the rank of the current trainer within ``group``. Ranks are
    consecutive integers in [0, ``world_size``). When ``group`` is None the
    rank from the global parallel environment is returned.

    Args:
        group (Group, optional): The communication group to query the rank of the current trainer in. The global parallel environment is used if group is None.

    Returns:
        (int) The rank of current trainer in the given group. Return -1 if the process is not part of the given group.

    Warning:
        Argument ``group`` is only supported in dygraph mode; passing it in
        static mode raises an AssertionError.

    Examples:
        .. code-block:: python

            # Execute this script using distributed launch with one card configs.
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            print("The rank is %d" % dist.get_rank())
            # The rank is 0
    """
    use_group = in_dynamic_mode() and group
    if not use_group:
        assert group is None, "Only support group argument in eager mode."
        return _get_global_parallel_env().rank
    return group.rank
1218 1219


L
LiYuRio 已提交
1220
def get_world_size(group=None):
    """

    Returns the number of trainers (processes taking part in the current job)
    in ``group``. When ``group`` is None, the global group is used if it has
    been initialized in dygraph mode; otherwise the world size of the global
    parallel environment is returned.

    Args:
        group (Group, optional): The communication group whose world size is queried. The global group is used if group is None.

    Returns:
        (int) The number of trainers in the given group. Return -1 if the process is not part of the given group.

    Warning:
        Argument ``group`` is only supported in dygraph mode; passing it in
        static mode raises an AssertionError.

    Examples:
        .. code-block:: python

            # Execute this script using distributed launch with one card configs.
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            print("The world_size is %d" % dist.get_world_size())
            # The world_size is 1
    """
    if in_dynamic_mode():
        # Fall back to the global group once it has been initialized.
        if group is None and is_initialized():
            group = _get_global_group()
        if group:
            return group.world_size

    assert group is None, "Only support group argument in eager mode."
    return _get_global_parallel_env().world_size