parallel.py 22.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16
import six
Y
Yan Xu 已提交
17
import numpy as np
18
import warnings
19
from collections import OrderedDict
20 21 22 23 24 25 26

from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
27
import warnings
28
import paddle
29
import itertools
30

31
# Names exported by ``from ... import *``.
__all__ = [
    "prepare_context",
    "ParallelEnv",
    "DataParallel",
]

# Module-level alias of ``core.ParallelStrategy`` used throughout this file.
ParallelStrategy = core.ParallelStrategy


36
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the global parallel context used by dygraph data parallelism.

    Deprecated since 2.0.0, use ``paddle.distributed.init_parallel_env``.

    Args:
        strategy(ParallelStrategy, optional): The parallel strategy. When None,
            one is built from the environment (``nranks``, ``local_rank``,
            ``trainer_endpoints`` and ``current_endpoint`` as read by
            ``ParallelEnv``). Default: None.

    Returns:
        ParallelStrategy|None: ``None`` when ``strategy.nranks < 2`` (nothing
        is initialized), otherwise the strategy that was used.

    Raises:
        NotImplementedError: If no parallel context is initialized yet and the
            current place is neither ``CUDAPlace`` nor ``XPUPlace``.
    '''
    if strategy is None:
        strategy = _build_default_parallel_strategy()
    if strategy.nranks < 2:
        return
    assert framework.in_dygraph_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_helper._set_parallel_ctx(
                core.NCCLParallelContext(strategy, place))
        elif isinstance(place, core.XPUPlace):
            parallel_helper._set_parallel_ctx(
                core.BKCLParallelContext(strategy, place))
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # An `assert` on a bare non-empty string literal is always true and
            # would let execution continue to _init_parallel_ctx() with no
            # context set, so fail explicitly here instead.
            raise NotImplementedError(
                "Only support CUDAPlace or XPUPlace for now.")
        parallel_helper._init_parallel_ctx()
    return strategy
66 67


68 69
class ParallelEnv(object):
    """
    .. note::
        This API is not recommended. If you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # imperative only supports one gpu or xpu: the first id listed in the
        # selection flag is used. Builds compiled with neither CUDA nor XPU
        # keep the documented default of 0, so ``device_id`` is always defined
        # (otherwise reading it raises AttributeError).
        self._device_id = 0
        if core.is_compiled_with_cuda():
            selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
            self._device_id = int(selected_gpus[0])
        elif core.is_compiled_with_xpu():
            selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
            self._device_id = int(selected_xpus[0])

        # NOTE: when PADDLE_TRAINER_ENDPOINTS is unset this evaluates to [""].
        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the selected device card for parallel training.

        In CUDA builds its value is the first ID in the environment variable
        ``FLAGS_selected_gpus``; in XPU builds the first ID in
        ``FLAGS_selected_xpus``. The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS``
        split on ``,``. When the variable is unset the value is ``['']``.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.
        It must be an integer in the range [1, 8].

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id


# NOTE: [ Compatible ] Originally this class name is `Env`. The semantics of the old class names
# are inaccurate and may confuse users, so replace it with `ParallelEnv`, but to be compatible
# with the old examples, here still need to keep this name.
Env = ParallelEnv


258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310
def _build_default_parallel_strategy():
    """
    Build a ``ParallelStrategy`` populated from the current environment.

    A single ``ParallelEnv`` instance is used for all fields instead of
    re-reading and re-validating the environment once per field.

    Returns:
        ParallelStrategy: strategy with ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` set.
    """
    strategy = ParallelStrategy()
    env = ParallelEnv()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """
    Flatten each group of gradient variables and concatenate it into one tensor.

    Args:
        var_groups(dict): maps a group key to the list of gradient variables
            in that group. Only the values are used.

    Returns:
        list: one ``[coalesced_grad, grad_vars, g_var_shapes]`` entry per
        group, in ``var_groups`` iteration order. ``coalesced_grad`` is the
        concatenation of the group's variables, each reshaped to 1-D.
        ``g_var_shapes`` records each variable's original shape.
    """
    from ..layers import nn
    coalesced_grads_and_grad_vars = []
    for grad_vars in var_groups.values():
        g_var_shapes = [g_var.shape for g_var in grad_vars]
        # int(): np.prod returns a numpy scalar (and the float 1.0 for an
        # empty shape); keep the target shape a list of plain Python ints.
        flattened_vars = [
            nn.reshape(
                x=g_var, shape=[int(np.prod(g_var.shape))])
            for g_var in grad_vars
        ]
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes])
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """
    Reshape ``x`` to ``shape`` in place.

    This traces a ``reshape2`` op whose ``Out`` is ``x`` itself.

    Args:
        x: dygraph variable to reshape. It is both the op input and output.
        shape(list): target shape.
    """
    # reshape2 also writes an XShape output; give it a scratch variable
    # of the same dtype as x.
    scratch_xshape = framework._varbase_creator(dtype=x.dtype)
    op_inputs = {'X': x}
    op_outputs = {'Out': x, 'XShape': scratch_xshape}
    op_attrs = {'shape': shape}
    tracer = framework._dygraph_tracer()
    tracer.trace_op(
        type="reshape2",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs=op_attrs)


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """
    Split coalesced gradients back into their original variables, in place.

    Each entry's ``coalesced_grad`` is split along axis 0 into
    ``origin_grad_vars``. Each piece is then reshaped in place to its
    recorded shape.

    Args:
        coalesced_grads_and_grad_vars(list): entries of the form
            ``[coalesced_grad, origin_grad_vars, grad_shapes]``.
    """
    for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
        # Cast to int: np.prod yields numpy scalars (the float 1.0 for an
        # empty shape), while the 'sections' attribute is a list of ints.
        grad_var_len = [int(np.prod(g_shape)) for g_shape in grad_shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_grad},
            outputs={'Out': origin_grad_vars},
            attrs={'sections': grad_var_len,
                   'axis': 0})
        for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
            _reshape_inplace(x=g_var, shape=g_shape)
            assert g_var.shape == g_shape


def scale_loss(loss):
    """
    Divide ``loss`` by the number of trainers when there is more than one.

    Args:
        loss: the loss variable to scale.

    Returns:
        ``loss / world_size`` when ``world_size > 1``, otherwise ``loss``
        unchanged.
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    # The divisor is a constant; keep it out of gradient computation.
    loss_scale = to_variable(np.array([world_size]).astype("float32"))
    loss_scale.stop_gradient = True
    return loss / loss_scale


322
class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, the DataParallel class only supports running the dynamic graph
    with multiple processes.

    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn needs to be called in the ``__main__`` block)

    2. start by ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .

    And the content of `demo.py` is the code of examples.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            contains environment configuration related to parallel execution. Default: None.
        comm_buffer_size(int, optional): It limits the memory size (MB) of one buffer of
            parameters' gradients, which is the input of a communication call
            (e.g. NCCLAllReduce). Default: 25.
        last_comm_buffer_size(float, optional): It limits the memory size (MB) of the last
            buffer in communication calls. Making the last communication buffer small is
            useful to improve performance. Default: 1.

    Returns:
        Layer: The data paralleled module.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. initialize parallel environment
                dist.init_parallel_env()

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
    """

    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1):
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        self._layers = layers

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # Buffer sizes are given in MB; store them in bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group, Default: 1MB. The role of this small group is:
            # when the last group allreduce, the overlap cannot work. Making the
            # the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")

    def init_reducer(self):
        """
        Group the trainable parameters into communication buffers and create
        the ``core.Reducer`` that synchronizes their gradients.

        Each parameter is collected once, even if it is shared by several
        sublayers.

        Raises:
            TypeError: if a parameter is not a ``core.VarBase``.
            AssertionError: if no parallel context has been initialized.
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, core.VarBase):
                    raise TypeError("The data type of '%s' must be Varbase" %
                                    param.name)
                if param.trainable:
                    layers_param.append((sublayer, param))

        trainable_parameters = [param for _, param in layers_param]

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check will also be removed here.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in layers_param
        ]

        self.group_indices = core.assign_group_by_size(
            trainable_parameters, is_sparse_gradient,
            [self.last_comm_buffer_size, self.comm_buffer_size])

        assert parallel_helper.__parallel_ctx__clz__ is not None, \
            "ParallelContext must be initialized before. You should use init_parallel_env() before " \
            "constructing the DataParallel."

        # TODO(shenliang03) "find_unused_vars" interface will be exposed in the future
        # to handle control flow to process unused parameters
        find_unused_vars = True
        self._reducer = core.Reducer(
            trainable_parameters,
            list(reversed(self.group_indices)), is_sparse_gradient,
            parallel_helper.__parallel_ctx__clz__,
            [self.last_comm_buffer_size, self.comm_buffer_size],
            find_unused_vars)

    def _find_varbase(self, obj):
        """
        Collect every ``core.VarBase`` nested in ``obj``.

        The search recurses into lists, tuples and dict values. Any other
        object contributes nothing.

        Returns:
            An iterable of the ``core.VarBase`` objects found.
        """
        if isinstance(obj, core.VarBase):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return itertools.chain(*map(self._find_varbase, obj))
        if isinstance(obj, dict):
            return itertools.chain(*map(self._find_varbase, obj.values()))
        return []

    def forward(self, *inputs, **kwargs):
        """
        Run the wrapped layers. With more than one rank, pass every output
        variable to the reducer's ``prepare_for_backward``.
        """
        outputs = self._layers(*inputs, **kwargs)
        if self._strategy.nranks > 1:
            self._reducer.prepare_for_backward(
                list(self._find_varbase(outputs)))

        return outputs

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated method. ``scale_loss`` now returns ``loss`` unchanged;
        it is kept only for compatibility.
        """
        return loss

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated method. ``apply_collective_grads`` is now a no-op;
        it is kept only for compatibility.
        """
        return

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Forwarded to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    @framework.deprecate_stat_dict
    def set_state_dict(self,
                       state_dict,
                       include_sublayers=True,
                       use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict