# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import six
import numpy as np
import warnings
from collections import OrderedDict

from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated

__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]

ParallelStrategy = core.ParallelStrategy


@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative
    '''
    if strategy is None:
        strategy = ParallelStrategy()
        strategy.nranks = Env().nranks
        strategy.local_rank = Env().local_rank
        strategy.trainer_endpoints = Env().trainer_endpoints
        strategy.current_endpoint = Env().current_endpoint
    if strategy.nranks < 2:
        return
    assert framework.in_dygraph_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_helper._set_parallel_ctx(
                core.NCCLParallelContext(strategy, place))
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            assert ("Only support CUDAPlace for now.")
        parallel_helper._init_parallel_ctx()
    return strategy


class ParallelEnv(object):
    """
    .. note::
        This API is not recommended; if you need the rank and world size,
        use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` instead.

    This class is used to obtain the environment variables required for 
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.

    The parallel execution in dynamic mode needs to be started using ``paddle.distributed.launch`` 
    or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (the number of processes participating in the current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of selected GPU card for parallel training.

        Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The device id are %d" % env.device_id)
165 166
            # The device id are 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of the current trainer; it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python
            
            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task, 
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist
            
            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id


# NOTE: [ Compatible ] Originally this class was named `Env`. The old name is
# inaccurate and may confuse users, so it was replaced with `ParallelEnv`; the
# old name is kept here for compatibility with existing examples.
Env = ParallelEnv


class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, the DataParallel class only supports running the dynamic graph
    with multiple processes.

    Two ways to start training are supported:

    1. Start by the ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (``spawn`` needs to be called in the ``__main__`` method)

    2. Start by the ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --selected_gpus=0,1 demo.py`` .

    The content of ``demo.py`` is the code shown in the following example.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            which contains environment configuration related to parallel execution. Default: None.

    Returns:
        Layer: The data paralleled module.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)
                    
                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. enable dynamic mode
                paddle.disable_static()
                
                # 2. initialize parallel environment
                dist.init_parallel_env()

                # 3. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 4. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)
                
                loss = dp_layer.scale_loss(loss)
                loss.backward()
                dp_layer.apply_collective_grads()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
    """

    def __init__(self, layers, strategy=None):
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        self._layers = layers

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy. 
        # It just stores some environment variables, which can be constructed by 
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is kept for compatibility with 1.x-style code.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = ParallelStrategy()
            self._strategy.nranks = ParallelEnv().nranks
            self._strategy.local_rank = ParallelEnv().local_rank
            self._strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
            self._strategy.current_endpoint = ParallelEnv().current_endpoint

    def forward(self, *inputs, **kwargs):
        return self._layers(*inputs, **kwargs)

    def scale_loss(self, loss):
        """
        Scale the loss. In data parallel mode, the loss should be scaled by
        the number of trainers. If not in data parallel mode, return the loss
        directly.

        Args:
            loss(Variable): The loss of the current Model.

        Returns:
            Variable: the scaled loss.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                import paddle.optimizer as opt
                import paddle.distributed as dist

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)
                        
                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()
                    
                    # 2. initialize parallel environment
                    dist.init_parallel_env()

                    # 3. create data parallel layer & optimizer
                    layer = LinearNet()
                    dp_layer = paddle.DataParallel(layer)

                    loss_fn = nn.MSELoss()
                    adam = opt.Adam(
                        learning_rate=0.001, parameters=dp_layer.parameters())

                    # 4. run layer
                    inputs = paddle.randn([10, 10], 'float32')
                    outputs = dp_layer(inputs)
                    labels = paddle.randn([10, 1], 'float32')
                    loss = loss_fn(outputs, labels)
                    
                    loss = dp_layer.scale_loss(loss)
                    loss.backward()
                    dp_layer.apply_collective_grads()

                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    # 1. start by ``paddle.distributed.spawn`` (default)
                    dist.spawn(train, nprocs=2)
                    # 2. start by ``paddle.distributed.launch``
                    # train()
        """
        if not self._is_data_parallel_mode():
            return loss

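        # Allreduce sums gradients across trainers, so dividing the local loss
        # by nranks turns the summed gradient into an average. For example,
        # with nranks = 2, each trainer backpropagates loss / 2, and the two
        # summed gradients equal the mean gradient over both trainers.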
        loss_scale = to_variable(
            np.array([self._strategy.nranks]).astype("float32"))
        loss_scale.stop_gradient = True
        loss = loss / loss_scale
        return loss

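    # Flatten every gradient in a group to 1-D and concatenate the group into
    # a single tensor, so one allreduce call covers the whole group. For
    # example, grads of shapes [2, 3] and [4] become one coalesced tensor of
    # length 10; the original shapes are recorded so _split_tensors can
    # restore them afterwards.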
    def _coalesce_tensors(self, var_groups):
        from ..layers import nn
        coalesced_grads_and_grad_vars = []
        for group_id, grad_vars in var_groups.items():
            flattened_vars = []
            g_var_shapes = []
            for g_var in grad_vars:
                g_var_shapes.append(g_var.shape)
                flattened_vars.append(
                    nn.reshape(
                        x=g_var, shape=[np.prod(g_var.shape)]))
            coalesced_grad = nn.concat(flattened_vars)
            coalesced_grads_and_grad_vars.append(
                [coalesced_grad, grad_vars, g_var_shapes])
        return coalesced_grads_and_grad_vars

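    # Append a reshape2 op whose output is the input variable itself, so the
    # gradient is reshaped in place without allocating a new tensor.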
    def _reshape_inplace(self, x, shape):
        x_shape = self._helper.create_variable_for_type_inference(dtype=x.dtype)
        self._helper.append_op(
            type="reshape2",
            inputs={'X': x},
            attrs={'shape': shape},
            outputs={'Out': x,
                     'XShape': x_shape})

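    # Inverse of _coalesce_tensors: split each coalesced tensor back into the
    # original gradient variables, then restore their original shapes in place.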
    def _split_tensors(self, coalesced_grads_and_grad_vars):
        from ..layers import nn
        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
            self._helper.main_program.current_block().append_op(
                type='split',
                inputs={'X': coalesced_grad},
                outputs={'Out': origin_grad_vars},
                attrs={'sections': grad_var_len,
                       'axis': 0})
            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
                self._reshape_inplace(x=g_var, shape=g_shape)
                assert g_var.shape == g_shape

    @no_grad
    def apply_collective_grads(self):
        """
        AllReduce the parameters' gradients.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.nn as nn
                import paddle.optimizer as opt
                import paddle.distributed as dist

                class LinearNet(nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__()
                        self._linear1 = nn.Linear(10, 10)
                        self._linear2 = nn.Linear(10, 1)
                        
                    def forward(self, x):
                        return self._linear2(self._linear1(x))

                def train():
                    # 1. enable dynamic mode
                    paddle.disable_static()
                    
                    # 2. initialize parallel environment
                    dist.init_parallel_env()

                    # 3. create data parallel layer & optimizer
                    layer = LinearNet()
                    dp_layer = paddle.DataParallel(layer)

                    loss_fn = nn.MSELoss()
                    adam = opt.Adam(
                        learning_rate=0.001, parameters=dp_layer.parameters())

                    # 4. run layer
                    inputs = paddle.randn([10, 10], 'float32')
                    outputs = dp_layer(inputs)
                    labels = paddle.randn([10, 1], 'float32')
                    loss = loss_fn(outputs, labels)
                    
                    loss = dp_layer.scale_loss(loss)
                    loss.backward()
                    dp_layer.apply_collective_grads()

                    adam.step()
                    adam.clear_grad()

                if __name__ == '__main__':
                    # 1. start by ``paddle.distributed.spawn`` (default)
                    dist.spawn(train, nprocs=2)
                    # 2. start by ``paddle.distributed.launch``
                    # train()
        """
        if not self._is_data_parallel_mode():
            return

        grad_var_set = set()
        grad_vars = []
        sparse_grad_vars = []
        for param in self._layers.parameters():
            # NOTE(zcd): The grad_ivar may not be generated.
            if param.trainable and (param._grad_ivar() is not None):
                g_var = param._grad_ivar()
                if g_var._is_sparse():
                    sparse_grad_vars.append(g_var)
                    continue
                grad_vars.append(g_var)
                assert g_var not in grad_var_set
                grad_var_set.add(g_var)

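        # Sparse gradients cannot be coalesced, so they are allreduced one by
        # one. Sorting by name gives every trainer the same order, so the
        # collective calls match up across processes.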
        if sparse_grad_vars:
            sparse_grad_vars.sort(key=lambda x: x.name)
            for grad_var in sparse_grad_vars:
                grad_var._allreduce(self._strategy)

        # FIXME(zcd): the type of the var should be LoDTensor, i.e.
        # the gradients should be dense; otherwise, the following
        # logic should be updated.
        # 128 MB as a group
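        # For example, a float32 gradient of shape [1024, 1024] occupies
        # 1024 * 1024 * 4 bytes = 4 MB, so about 32 such gradients fill one
        # 128 MB group and are allreduced together.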
        mega_bytes = 128 * 1024 * 1024
        group_idx = 0
        memory_counter = 0
        grad_var_groups = OrderedDict()
        dtype = grad_vars[0].dtype
        for g_var in grad_vars:
            # Note: the dtype of the same group should be the same.
            bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
            if memory_counter < mega_bytes and dtype == g_var.dtype:
                memory_counter += bytes
            else:
                memory_counter = bytes
                # update dtype so later tensors are compared against the dtype
                # of the group they actually start
                dtype = g_var.dtype
                group_idx += 1
            grad_var_groups.setdefault(group_idx, []).append(g_var)

        coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)

        for coalesced_grad, _, _ in coalesced_grads_and_vars:
            coalesced_grad._allreduce(self._strategy)

        self._split_tensors(coalesced_grads_and_vars)

    def _is_data_parallel_mode(self):
        return self._strategy.nranks > 1

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters of self._layers and its sub-layers, and set all the parameters into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters will be set into this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            structured_name_prefix(str, optional): If not empty str, all the key in state dict will start 
                                                 with structured_name_prefix

        Returns:
            dict: a dict containing all the parameters of self._layers

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    strategy = fluid.dygraph.prepare_context()
                    emb = fluid.dygraph.Embedding([10, 10])
                    emb = fluid.dygraph.DataParallel(emb, strategy)

                    state_dict = emb.state_dict()
                    fluid.save_dygraph(state_dict, "paddle_dy")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    @framework.deprecate_stat_dict
    def set_state_dict(self,
                       state_dict,
                       include_sublayers=True,
                       use_structured_name=True):
        '''
        Set parameters of self._layers from state_dict. All the parameters of self._layers will be reset by the tensors in the state_dict.

        Parameters:
            state_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key. 
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                paddle.disable_static()
                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")

                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict