parallel.py 19.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16
import six
Y
Yan Xu 已提交
17
import numpy as np
18
import warnings
19
from collections import OrderedDict
20 21 22 23 24 25 26

from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
27 28
from paddle.fluid.dygraph import nn
import warnings
29

30
__all__ = ['prepare_context', 'ParallelEnv', 'DataParallel']

# Module-level re-export of ``core.ParallelStrategy`` used throughout this file.
ParallelStrategy = core.ParallelStrategy


35
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the global dygraph parallel context (an NCCL context on a
    ``CUDAPlace``) if it has not been initialized yet.

    Deprecated since 2.0.0; use ``paddle.distributed.init_parallel_env``.

    Args:
        strategy(ParallelStrategy, optional): Parallel configuration. If None,
            one is built from the environment variables read by
            ``ParallelEnv``. Default: None.

    Returns:
        ParallelStrategy|None: The strategy that was used, or None when the
        strategy has fewer than 2 ranks (nothing is initialized then).

    Raises:
        NotImplementedError: If no parallel context exists yet and the
            current expected place is not a ``CUDAPlace``.
    '''
    if strategy is None:
        strategy = _build_default_parallel_strategy()
    if strategy.nranks < 2:
        return
    assert framework.in_dygraph_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_helper._set_parallel_ctx(
                core.NCCLParallelContext(strategy, place))
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # Raise explicitly: asserting a bare (always truthy) string can
            # never fail, which would let unsupported places fall through.
            raise NotImplementedError("Only support CUDAPlace for now.")
        parallel_helper._init_parallel_ctx()
    return strategy
62 63


64 65
class ParallelEnv(object):
    """
    .. note::
        This API is not recommended. If you need to get rank and world_size,
        it is recommended to use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` .

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch``
    or ``paddle.distributed.spawn`` .

    All values are read from the environment once, when the instance is
    constructed; later changes to the environment are not reflected in an
    existing instance.

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # imperative mode only supports one GPU per process, so only the
        # first id listed in FLAGS_selected_gpus is used
        selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
        self._device_id = int(selected_gpus[0])

        # NOTE: an unset (or empty) variable yields [''] here, not [].
        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of selected GPU card for parallel training.

        Its value is the first id listed in the environment variable ``FLAGS_selected_gpus``
        (a comma-separated list). The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS`` split on ``,``.
        If the variable is unset, the value is ``['']`` (a list holding one empty string).

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

219

220 221 222 223 224 225
# NOTE: [ Compatible ] `Env` was the original name of `ParallelEnv`. That name
# did not describe the class accurately and could confuse users, so the class
# was renamed; `Env` remains as an alias so that older examples keep working.
Env = ParallelEnv


226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from the current environment.

    Returns:
        ParallelStrategy: a strategy whose ``nranks``, ``local_rank``,
        ``trainer_endpoints`` and ``current_endpoint`` are copied from a
        single ``ParallelEnv`` instance.
    """
    # Read the environment once so that all four fields come from one
    # consistent snapshot.
    env = ParallelEnv()
    strategy = ParallelStrategy()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """Flatten and concatenate each group of gradient variables.

    Args:
        var_groups(dict): mapping whose values are lists of gradient
            variables; the keys are not used. Groups are processed in the
            mapping's iteration order.

    Returns:
        list: one ``[coalesced_grad, grad_vars, g_var_shapes]`` entry per
        group, where ``coalesced_grad`` is the concatenation of every
        variable in ``grad_vars`` reshaped to 1-D, and ``g_var_shapes``
        records each variable's original shape so the result can be split
        back later.
    """
    # Local import: this ``nn`` comes from the sibling ``layers`` package and
    # shadows the module-level ``nn`` only inside this function.
    from ..layers import nn
    coalesced_grads_and_grad_vars = []
    for grad_vars in var_groups.values():
        g_var_shapes = [g_var.shape for g_var in grad_vars]
        # int() because np.prod returns a float (1.0) for an empty shape.
        flattened_vars = [
            nn.reshape(
                x=g_var, shape=[int(np.prod(g_shape))])
            for g_var, g_shape in zip(grad_vars, g_var_shapes)
        ]
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes])
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape ``x`` to ``shape`` in place by tracing a ``reshape2`` op.

    The op's ``Out`` is ``x`` itself, so ``x`` is updated directly; the
    auxiliary ``XShape`` output is written into a separate throw-away
    variable created with ``x``'s dtype.
    """
    xshape_holder = framework._varbase_creator(dtype=x.dtype)
    op_inputs = {'X': x}
    op_outputs = {'Out': x, 'XShape': xshape_holder}
    tracer = framework._dygraph_tracer()
    tracer.trace_op(
        type="reshape2",
        inputs=op_inputs,
        outputs=op_outputs,
        attrs={'shape': shape})


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter coalesced gradients back into their original variables.

    Each entry is ``[coalesced_grad, origin_grad_vars, grad_shapes]`` (the
    layout produced by ``_coalesce_tensors``). ``coalesced_grad`` is split
    along axis 0 into ``origin_grad_vars``, one section per variable sized
    by the element count of its shape. Each piece is then reshaped in place
    back to its recorded shape.
    """
    for entry in coalesced_grads_and_grad_vars:
        coalesced, grad_vars, shapes = entry
        sections = list(map(np.prod, shapes))
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced},
            outputs={'Out': grad_vars},
            attrs={'sections': sections,
                   'axis': 0})
        for var, shape in zip(grad_vars, shapes):
            _reshape_inplace(x=var, shape=shape)
            assert var.shape == shape


def scale_loss(loss):
    """Divide ``loss`` by the number of trainers when running in parallel.

    Args:
        loss(Variable): the loss to scale.

    Returns:
        Variable: ``loss / world_size`` when ``world_size > 1``; otherwise
        ``loss`` itself, unchanged.
    """
    # Read the environment once and reuse the value below.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    loss_scale = to_variable(np.array([world_size]).astype("float32"))
    # The divisor is a constant; no gradient should flow into it.
    loss_scale.stop_gradient = True
    return loss / loss_scale


289
class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, the DataParallel class only supports running the dynamic graph
    with multiple processes.

    There are two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn needs to be called under ``if __name__ == '__main__':``)

    2. start by ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .

    The content of `demo.py` is the code of the examples below.

    If the strategy has fewer than 2 ranks, no gradient reducer is created
    (a warning is emitted) and the wrapped layer is simply run as-is.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            contains environment configuration related to parallel execution. Default: None.
        comm_buffer_size(int, optional): The maximum memory size (MB) of one buffer of
            parameters' gradients, which is the input of one communication call
            (e.g. NCCLAllReduce). Default: 25.
        last_comm_buffer_size(float, optional): The maximum memory size (MB) of the last
            communication buffer. Making the last communication buffer small is useful to
            improve performance. Default: 1.

    Returns:
        Layer: The data paralleled module.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. initialize parallel environment
                dist.init_parallel_env()

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
    """

    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1):
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        self._layers = layers

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # Buffer sizes are given in MB; convert them to bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): The size of the last group is controlled by
            # ``last_comm_buffer_size`` (default: 1MB). The allreduce of the
            # last group cannot be overlapped, so making the last group small
            # is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn(
                "nranks is less than 2, "
                "maybe you need to check the current system environment."
                " Need to use spawn or fleetrun to "
                "start distributed programs.")

    def init_reducer(self):
        """Create the gradient ``core.Reducer`` for the trainable parameters.

        Collects each unique, non-None parameter of every sublayer, keeps
        the trainable ones, groups them by size with
        ``core.assign_group_by_size`` using
        ``[last_comm_buffer_size, comm_buffer_size]``, and builds a
        ``core.Reducer`` over the groups (in reversed order) with the global
        parallel context.

        Raises:
            TypeError: if a parameter is not a ``core.VarBase``.
            AssertionError: if the parallel context has not been initialized
                (checked with ``assert``, so skipped under ``python -O``).
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for _, param in sublayer.named_parameters(include_sublayers=False):
                # A parameter shared by several sublayers is registered once.
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, core.VarBase):
                    raise TypeError("The data type of '%s' must be Varbase" %
                                    param.name)
                if param.trainable:
                    layers_param.append((sublayer, param))

        trainable_parameters = [param for _, param in layers_param]

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "nn.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, nn.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in layers_param
        ]

        self.group_indices = core.assign_group_by_size(
            trainable_parameters, is_sparse_gradient,
            [self.last_comm_buffer_size, self.comm_buffer_size])

        assert parallel_helper.__parallel_ctx__clz__ is not None, \
            "ParallelContext must be initialized before. You should use init_parallel_env() before " \
            "constructing the DataParallel."

        self._reducer = core.Reducer(trainable_parameters,
                                     list(reversed(self.group_indices)),
                                     is_sparse_gradient,
                                     parallel_helper.__parallel_ctx__clz__)

    def forward(self, *inputs, **kwargs):
        # The reducer only exists when running with more than one rank.
        if self._strategy.nranks > 1:
            self._reducer.prepare_for_backward()

        return self._layers(*inputs, **kwargs)

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated method, now ``scale_loss`` is an empty method,
        keep this method just for compatibility.
        """
        return loss

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated method, now ``apply_collective_grads`` is an empty method,
        keep this method just for compatibility.
        """
        return

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of the wrapped layer and its sub-layers,
        and set them into a dict. The call is delegated to the wrapped layer's ``state_dict``.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Passed through to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    @framework.deprecate_stat_dict
    def set_state_dict(self,
                       state_dict,
                       include_sublayers=True,
                       use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict.
        The call is delegated to the wrapped layer's ``set_state_dict``.

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict