# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import six
import numpy as np
import warnings
from collections import OrderedDict

from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
import paddle
import itertools

__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]

ParallelStrategy = core.ParallelStrategy


@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative
    '''
    if strategy is None:
        strategy = ParallelStrategy()
        strategy.nranks = Env().nranks
        strategy.local_rank = Env().local_rank
        strategy.trainer_endpoints = Env().trainer_endpoints
        strategy.current_endpoint = Env().current_endpoint
    if strategy.nranks < 2:
        return
    assert framework.in_dygraph_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_helper._set_parallel_ctx(
                core.NCCLParallelContext(strategy, place))
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            raise AssertionError("Only CUDAPlace is supported for now.")
        parallel_helper._init_parallel_ctx()
    return strategy
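
# Hedged usage sketch of the deprecated path above (new code should prefer
# paddle.distributed.init_parallel_env); names follow the 1.x dygraph API, and
# MyLayer is an illustrative layer class, not part of this module:
#
#   import paddle.fluid as fluid
#   place = fluid.CUDAPlace(ParallelEnv().dev_id)
#   with fluid.dygraph.guard(place):
#       strategy = prepare_context()
#       model = fluid.dygraph.DataParallel(MyLayer(), strategy)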


class ParallelEnv(object):
    """
    .. note::
        This API is not recommended; if you need the rank and world size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` instead.

    This class is used to obtain the environment variables required for 
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # imperative mode only supports one GPU
        selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
        self._device_id = int(selected_gpus[0])

        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of selected GPU card for parallel training.

        Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The device id are %d" % env.device_id)
177 178
            # The device id are 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of the current trainer, in the form of IP address plus port (e.g. 127.0.0.1:6170).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python
            
            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
202 203 204 205
        """
        The endpoints of all trainer nodes in the task, 
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Number of NCCL communication rings (nrings) of the current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id


# NOTE: [ Compatible ] This class was originally named `Env`. The old name was
# semantically inaccurate and could confuse users, so it was replaced with
# `ParallelEnv`; the alias below is kept for compatibility with old examples.
Env = ParallelEnv


def _build_default_parallel_strategy():
    strategy = ParallelStrategy()
    strategy.nranks = ParallelEnv().nranks
    strategy.local_rank = ParallelEnv().local_rank
    strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
    strategy.current_endpoint = ParallelEnv().current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    from ..layers import nn
    coalesced_grads_and_grad_vars = []
    for group_id, grad_vars in var_groups.items():
        flattened_vars = []
        g_var_shapes = []
        for g_var in grad_vars:
            g_var_shapes.append(g_var.shape)
            flattened_vars.append(
                nn.reshape(
                    x=g_var, shape=[np.prod(g_var.shape)]))
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes])
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    x_shape = framework._varbase_creator(dtype=x.dtype)
    framework._dygraph_tracer().trace_op(
        type="reshape2",
        inputs={'X': x},
        outputs={'Out': x,
                 'XShape': x_shape},
        attrs={'shape': shape})


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
        grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_grad},
            outputs={'Out': origin_grad_vars},
            attrs={'sections': grad_var_len,
                   'axis': 0})
        for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
            _reshape_inplace(x=g_var, shape=g_shape)
            assert g_var.shape == g_shape
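

# A minimal comment sketch of the coalesce/split round trip implemented above,
# assuming two gradient tensors with shapes [2, 3] and [4] (g0 and g1 are
# illustrative names, not part of this module):
#
#   var_groups = OrderedDict([(0, [g0, g1])])   # g0.shape == [2, 3], g1.shape == [4]
#   coalesced = _coalesce_tensors(var_groups)
#   # coalesced == [[flat tensor of shape [10], [g0, g1], [[2, 3], [4]]]]
#   _split_tensors(coalesced)                   # g0 and g1 recover their original shapes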


def scale_loss(loss):
    if not ParallelEnv().world_size > 1:
        return loss

    loss_scale = to_variable(
        np.array([ParallelEnv().world_size]).astype("float32"))
    loss_scale.stop_gradient = True
    scaled_loss = loss / loss_scale
    return scaled_loss
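

# Hedged arithmetic example for scale_loss above: with PADDLE_TRAINERS_NUM=4
# (so world_size == 4), a loss of 2.0 becomes 0.5, so that the gradients summed
# by allreduce across the 4 trainers match the average single-card gradient.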


class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, the DataParallel class only supports running the dynamic graph
    with multiple processes.

    Two ways to start training are supported:

    1. start by the ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn needs to be called in the ``__main__`` method)

    2. start by the ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .

    The content of ``demo.py`` is the code shown in the examples below.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            contains environment configuration related to parallel execution. Default: None.
        comm_buffer_size(int, optional): Limits the memory size (MB) of one gradient buffer
                                          that is the input of a communication call
                                          (e.g. NCCLAllReduce). Default: 25.
        last_comm_buffer_size(float, optional): Limits the memory size (MB) of the last
                                         communication buffer. Keeping the last communication
                                         buffer small helps improve performance. Default: 1.

    Returns:
        Layer: The data parallelized module.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
                    
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. initialize parallel environment
                dist.init_parallel_env()

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
                
                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
    """

    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1):
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        self._layers = layers

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is kept only for compatibility with 1.x usage.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): We can set environment variables to control
            # the size of the group. Default: 1MB. The role of this small group is:
            # during the allreduce of the last group, overlap cannot work. Making
            # the last group small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
425 426
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")

    def init_reducer(self):
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, core.VarBase):
                    raise TypeError("The data type of '%s' must be Varbase" %
                                    param.name)
                if param.trainable:
                    layers_param.append((sublayer, param))

        trainable_parameters = [param for _, param in layers_param]

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # a parameter is sparse (or SelectedRows). The reason is that the sparsity
        # information can't be obtained before backpropagation has happened. So if
        # a layer supports sparse parameters, it should be added here, like
        # "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03): This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check here will also be removed.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in layers_param
        ]

        self.group_indices = core.assign_group_by_size(
            trainable_parameters, is_sparse_gradient,
            [self.last_comm_buffer_size, self.comm_buffer_size])
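
        # Hedged illustration of the grouping above: with the defaults,
        # assign_group_by_size packs the first bucket of gradients up to ~1MB
        # and subsequent buckets up to ~25MB; after the reversal in the Reducer
        # below, the small bucket is communicated last, which is the overlap
        # optimization described in the NOTE above.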

        assert parallel_helper.__parallel_ctx__clz__ is not None, \
            "ParallelContext must be initialized before. You should use init_parallel_env() " \
            "before constructing the DataParallel."

        # TODO(shenliang03): the "find_unused_vars" interface will be exposed in
        # the future to handle control flow and process unused parameters
        find_unused_vars = True
        self._reducer = core.Reducer(
            trainable_parameters,
            list(reversed(self.group_indices)), is_sparse_gradient,
            parallel_helper.__parallel_ctx__clz__,
            [self.last_comm_buffer_size, self.comm_buffer_size],
            find_unused_vars)

    def _find_varbase(self, obj):
        if isinstance(obj, core.VarBase):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return itertools.chain(*map(self._find_varbase, obj))
        if isinstance(obj, dict):
            return itertools.chain(*map(self._find_varbase, obj.values()))
        return []
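
    # Hedged illustration of _find_varbase on a nested output structure (t0..t2
    # are illustrative core.VarBase instances, not part of this module):
    #
    #   outputs = {"logits": t0, "aux": [t1, (t2,)]}
    #   list(self._find_varbase(outputs))  # -> [t0, t1, t2]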

    def forward(self, *inputs, **kwargs):
        outputs = self._layers(*inputs, **kwargs)
        if self._strategy.nranks > 1:
            self._reducer.prepare_for_backward(
                list(self._find_varbase(outputs)))

        return outputs

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated method; ``scale_loss`` is now a no-op (it returns the loss
        unchanged) and is kept only for compatibility.
        """
        return loss

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated method; ``apply_collective_grads`` is now a no-op and is
        kept only for compatibility.
        """
        return

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of the current layer and its sub-layers, and set them into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    @framework.deprecate_stat_dict
    def set_state_dict(self,
                       state_dict,
                       include_sublayers=True,
                       use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensors in the state_dict.

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key. 
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict