parallel.py 21.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
15
import six
Y
Yan Xu 已提交
16
import numpy as np
17
from collections import OrderedDict
18
from .. import core
19
from . import layers
C
chengduo 已提交
20
from . import parallel_helper
21
from .. import framework
J
Jiabin Yang 已提交
22
from . import to_variable, no_grad
23

24
# Public names exported by `from ... import *` on this module.
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
25 26 27 28

# Module-level alias of `core.ParallelStrategy`; prepare_context() builds its
# default strategy from this class when the caller passes none.
ParallelStrategy = core.ParallelStrategy


C
chengduo 已提交
29
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Initialize the collective context used by dygraph data-parallel training.

    When ``strategy`` is None, a ``ParallelStrategy`` is built from the
    environment variables read by ``ParallelEnv`` (PADDLE_TRAINERS_NUM,
    PADDLE_TRAINER_ID, PADDLE_TRAINER_ENDPOINTS, PADDLE_CURRENT_ENDPOINT).

    Args:
        strategy(ParallelStrategy, optional): The parallel strategy to use.
            Default: None, meaning "build one from the environment".

    Returns:
        ParallelStrategy|None: The strategy used to initialize the parallel
        context, or None when ``strategy.nranks < 2`` (a single trainer needs
        no collective context).

    Raises:
        AssertionError: If not called in dygraph mode, or if no expected
            place is set (i.e. outside ``fluid.dygraph.guard(place)``).
        NotImplementedError: If the current place is not a CUDAPlace.
    '''
    if strategy is None:
        # Read the environment once instead of constructing it per field.
        env = ParallelEnv()
        strategy = ParallelStrategy()
        strategy.nranks = env.nranks
        strategy.local_rank = env.local_rank
        strategy.trainer_endpoints = env.trainer_endpoints
        strategy.current_endpoint = env.current_endpoint
    if strategy.nranks < 2:
        # Single trainer: nothing to initialize.
        return

    assert framework.in_dygraph_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."

    if not isinstance(place, core.CUDAPlace):
        # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
        # Raise explicitly: `assert ("msg")` asserts a non-empty string, which
        # is always truthy, so it would never stop an unsupported place.
        raise NotImplementedError("Only support CUDAPlace for now.")

    parallel_helper._set_parallel_ctx(
        core.NCCLParallelContext(strategy, place))
    parallel_helper._init_parallel_ctx()
    return strategy
54 55


56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
class ParallelEnv(object):
    """
    **Notes**:
        **The old class name was Env and will be deprecated. Please use new class name ParallelEnv.**

    This class is used to obtain the environment variables required for
    the parallel execution of dynamic graph model.

    The dynamic graph parallel mode needs to be started using paddle.distributed.launch.
    By default, the related environment variables are automatically configured by that module.

    This class is generally used together with `fluid.dygraph.DataParallel` to configure
    dynamic graph models to run in parallel.

    All values are read from the environment once, when the object is constructed.

    Examples:
      .. code-block:: python

        # This example needs to run with paddle.distributed.launch, The usage is:
        #   python -m paddle.distributed.launch --selected_gpus=0,1 example.py
        # And the content of `example.py` is the code of following example.

        import numpy as np
        import paddle.fluid as fluid
        import paddle.fluid.dygraph as dygraph
        from paddle.fluid.dygraph.nn import Linear
        from paddle.fluid.dygraph.base import to_variable

        place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
        with fluid.dygraph.guard(place=place):

            # prepare the data parallel context
            strategy = dygraph.prepare_context()

            linear = Linear(1, 10, act="softmax")
            adam = fluid.optimizer.AdamOptimizer(
                learning_rate=0.001, parameter_list=linear.parameters())

            # make the module become the data parallelism module
            linear = dygraph.DataParallel(linear, strategy)

            x_data = np.random.random(size=[10, 1]).astype(np.float32)
            data = to_variable(x_data)

            hidden = linear(data)
            avg_loss = fluid.layers.mean(hidden)

            # scale the loss according to the number of trainers.
            avg_loss = linear.scale_loss(avg_loss)

            avg_loss.backward()

            # collect the gradients of trainers.
            linear.apply_collective_grads()

            adam.minimize(avg_loss)
            linear.clear_gradients()
    """

    def __init__(self):
        self._nranks = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        self._local_rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        # FLAGS_selected_gpus may list several cards (e.g. "0,1"); int() on the
        # whole string would fail, so use the first listed card as the id.
        selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
        self._dev_id = int(selected_gpus[0])
        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")

    @property
    def nranks(self):
        """
        The number of trainers, generally refers to the number of GPU cards used in training.

        Its value is equal to the value of the environment variable PADDLE_TRAINERS_NUM. The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.fluid as fluid

            env = fluid.dygraph.ParallelEnv()
            print("The nranks is %d" % env.nranks)
            # The nranks is 4
        """
        return self._nranks

    @property
    def local_rank(self):
        """
        The current trainer number.

        Its value is equal to the value of the environment variable PADDLE_TRAINER_ID. The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.fluid as fluid

            env = fluid.dygraph.ParallelEnv()
            print("The local rank is %d" % env.local_rank)
            # The local rank is 0
        """
        return self._local_rank

    @property
    def dev_id(self):
        """
        The ID of selected GPU card for parallel training.

        Its value is the (first) card id listed in the environment variable
        FLAGS_selected_gpus, which may be a comma-separated list such as "1,2".
        The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.fluid as fluid

            env = fluid.dygraph.ParallelEnv()
            print("The device id is %d" % env.dev_id)
            # The device id is 1
        """
        return self._dev_id

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable PADDLE_CURRENT_ENDPOINT. The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.fluid as fluid

            env = fluid.dygraph.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable PADDLE_TRAINER_ENDPOINTS split on
        commas. If the variable is unset, the value is `['']` (a list holding
        one empty string).

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.fluid as fluid

            env = fluid.dygraph.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints


219 220 221 222 223 224
# NOTE: [ Compatible ] This class was originally named `Env`. That name did not
# describe its purpose accurately and could confuse users, so it was renamed to
# `ParallelEnv`. The old name is kept as an alias so existing code and examples
# that still use `Env` keep working.
Env = ParallelEnv


225
class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, DataParallel class only supports to run the dynamic graph
    with multi-process. The usage is:
    `python -m paddle.distributed.launch --selected_gpus=0,1 dynamic_graph_test.py`.
    And the content of `dynamic_graph_test.py` is the code of examples.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy): The strategy of data parallelism, contains
            environment configuration related to parallel execution. It may be
            None (``prepare_context`` returns None for a single trainer); the
            wrapper then runs in non-parallel mode and ``scale_loss`` /
            ``apply_collective_grads`` are no-ops.

    Returns:
        Layer: The data paralleled module.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle.fluid as fluid

            place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
            with fluid.dygraph.guard(place):

                # prepare the data parallel context
                strategy = fluid.dygraph.prepare_context()

                linear = fluid.dygraph.Linear(1, 10, act="softmax")
                adam = fluid.optimizer.AdamOptimizer(
                    learning_rate=0.001, parameter_list=linear.parameters())

                # make the module become the data parallelism module
                linear = fluid.dygraph.DataParallel(linear, strategy)

                x_data = np.random.random(size=[10, 1]).astype(np.float32)
                data = fluid.dygraph.to_variable(x_data)

                hidden = linear(data)
                avg_loss = fluid.layers.mean(hidden)

                # scale the loss according to the number of trainers.
                avg_loss = linear.scale_loss(avg_loss)

                avg_loss.backward()

                # collect the gradients of trainers.
                linear.apply_collective_grads()

                adam.minimize(avg_loss)
                linear.clear_gradients()
    """

    def __init__(self, layers, strategy):
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        self._layers = layers
        self._strategy = strategy

    def forward(self, *inputs, **kwargs):
        """Call the wrapped layer with the given arguments and return its output."""
        return self._layers(*inputs, **kwargs)

    def scale_loss(self, loss):
        """
        Scale the loss. In data parallel mode, the loss should be scaled by
        the number of trainers. If not in data parallel mode, return the loss
        directly.

        Args:
            loss(Variable): The loss of the current Model.

        Returns:
            Variable: the scaled loss.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
                with fluid.dygraph.guard(place):

                    # prepare the data parallel context
                    strategy = fluid.dygraph.prepare_context()

                    linear = fluid.dygraph.Linear(1, 10, act="softmax")
                    adam = fluid.optimizer.AdamOptimizer(
                        learning_rate=0.001, parameter_list=linear.parameters())

                    # make the module become the data parallelism module
                    linear = fluid.dygraph.DataParallel(linear, strategy)

                    x_data = np.random.random(size=[10, 1]).astype(np.float32)
                    data = fluid.dygraph.to_variable(x_data)

                    hidden = linear(data)
                    avg_loss = fluid.layers.mean(hidden)

                    # scale the loss according to the number of trainers.
                    avg_loss = linear.scale_loss(avg_loss)

                    avg_loss.backward()

                    # collect the gradients of trainers.
                    linear.apply_collective_grads()

                    adam.minimize(avg_loss)
                    linear.clear_gradients()
        """
        if not self._is_data_parallel_mode():
            return loss

        loss_scale = to_variable(
            np.array([self._strategy.nranks]).astype("float32"))
        # The divisor is a constant; no gradient should flow into it.
        loss_scale.stop_gradient = True
        loss = loss / loss_scale
        return loss

    def _coalesce_tensors(self, var_groups):
        """
        Flatten every gradient of each group (in place) and concatenate the
        flattened gradients of a group into one tensor.

        Args:
            var_groups(OrderedDict): group id -> list of gradient variables.

        Returns:
            list: one ``[coalesced_grad, grad_vars, original_shapes]`` entry
            per group, in group order; the shapes are recorded before the
            in-place reshape so ``_split_tensors`` can restore them.
        """
        from ..layers import nn
        coalesced_grads_and_grad_vars = []
        for grad_vars in var_groups.values():
            flattened_vars = []
            g_var_shapes = []
            for g_var in grad_vars:
                # Record the shape first: the reshape below is in place.
                g_var_shapes.append(g_var.shape)
                flattened_vars.append(
                    nn.reshape(
                        x=g_var, shape=[np.prod(g_var.shape)], inplace=True))
            coalesced_grad = nn.concat(flattened_vars)
            coalesced_grads_and_grad_vars.append(
                [coalesced_grad, grad_vars, g_var_shapes])
        return coalesced_grads_and_grad_vars

    def _reshape_inplace(self, x, shape):
        """Append a ``reshape2`` op that writes the reshaped result back into ``x``."""
        x_shape = self._helper.create_variable_for_type_inference(dtype=x.dtype)
        self._helper.append_op(
            type="reshape2",
            inputs={'X': x},
            attrs={'shape': shape},
            outputs={'Out': x,
                     'XShape': x_shape})

    def _split_tensors(self, coalesced_grads_and_grad_vars):
        """
        Inverse of ``_coalesce_tensors``: split each coalesced tensor back into
        its original gradient variables and restore their original shapes.
        """
        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
            self._helper.main_program.current_block().append_op(
                type='split',
                inputs={'X': coalesced_grad},
                outputs={'Out': origin_grad_vars},
                attrs={'sections': grad_var_len,
                       'axis': 0})
            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
                self._reshape_inplace(x=g_var, shape=g_shape)
                assert g_var.shape == g_shape

    @no_grad()
    def apply_collective_grads(self):
        """
        AllReduce the Parameters' gradient.

        Sparse gradients are all-reduced one by one (sorted by name); dense
        gradients are packed into groups of at most ~128 MB of a single dtype,
        all-reduced per group, then split back into the original gradients.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                place = fluid.CUDAPlace(fluid.dygraph.ParallelEnv().dev_id)
                with fluid.dygraph.guard(place):

                    # prepare the data parallel context
                    strategy = fluid.dygraph.prepare_context()

                    linear = fluid.dygraph.Linear(1, 10, act="softmax")
                    adam = fluid.optimizer.AdamOptimizer(
                        learning_rate=0.001, parameter_list=linear.parameters())

                    # make the module become the data parallelism module
                    linear = fluid.dygraph.DataParallel(linear, strategy)

                    x_data = np.random.random(size=[10, 1]).astype(np.float32)
                    data = fluid.dygraph.to_variable(x_data)

                    hidden = linear(data)
                    avg_loss = fluid.layers.mean(hidden)

                    # scale the loss according to the number of trainers.
                    avg_loss = linear.scale_loss(avg_loss)

                    avg_loss.backward()

                    # collect the gradients of trainers.
                    linear.apply_collective_grads()

                    adam.minimize(avg_loss)
                    linear.clear_gradients()
        """
        if not self._is_data_parallel_mode():
            return

        grad_var_set = set()
        grad_vars = []
        sparse_grad_vars = []
        for param in self._layers.parameters():
            # NOTE(zcd): The grad_ivar may not have been generated.
            if param.trainable and (param._grad_ivar() is not None):
                g_var = param._grad_ivar()
                if g_var._is_sparse():
                    sparse_grad_vars.append(g_var)
                    continue
                grad_vars.append(g_var)
                assert g_var not in grad_var_set
                grad_var_set.add(g_var)

        if sparse_grad_vars:
            # Sort so every trainer issues the all-reduces in the same order.
            sparse_grad_vars.sort(key=lambda x: x.name)
            for grad_var in sparse_grad_vars:
                grad_var._allreduce(self._strategy)

        if not grad_vars:
            # Only sparse (or no) gradients: nothing left to coalesce, and
            # indexing grad_vars[0] below would raise IndexError.
            return

        # FIXME(zcd): the type of the var should be LoDTensor, i.e
        # the gradients should be dense, otherwise, the following
        # logic should be updated.
        # 128 MB as a group
        mega_bytes = 128 * 1024 * 1024
        group_idx = 0
        memory_counter = 0
        grad_var_groups = OrderedDict()
        dtype = grad_vars[0].dtype
        for g_var in grad_vars:
            # Note: the dtype of the same group should be the same.
            var_bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
            if memory_counter < mega_bytes and dtype == g_var.dtype:
                memory_counter += var_bytes
            else:
                # Start a new group and remember its dtype; otherwise later
                # vars would be compared against the first group's dtype and
                # could join a group of a different dtype.
                memory_counter = var_bytes
                dtype = g_var.dtype
                group_idx += 1
            grad_var_groups.setdefault(group_idx, []).append(g_var)

        coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)

        for coalesced_grad, _, _ in coalesced_grads_and_vars:
            coalesced_grad._allreduce(self._strategy)

        self._split_tensors(coalesced_grads_and_vars)

    def _is_data_parallel_mode(self):
        # A None strategy (what prepare_context returns for a single trainer)
        # means non-parallel mode rather than an AttributeError.
        return self._strategy is not None and self._strategy.nranks > 1

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters of self._layers and its sub-layers. And set all the parameters into a dict

        Parameters:
            destination(dict, optional) : If provided, all the parameters will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            structured_name_prefix(str, optional): If not empty str, all the keys in the state dict will start
                                                 with structured_name_prefix

        Returns:
            dict: a dict contains all the parameters of self._layers

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    strategy=fluid.dygraph.prepare_context()
                    emb = fluid.dygraph.Embedding([10, 10])
                    emb = fluid.dygraph.DataParallel(emb, strategy)

                    state_dict = emb.state_dict()
                    fluid.save_dygraph(state_dict, "paddle_dy")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    def set_dict(self,
                 stat_dict,
                 include_sublayers=True,
                 use_structured_name=True):
        '''
        Set parameters of self._layers from stat_dict. All the parameters of self._layers will be reset by the tensor in the stat_dict

        Parameters:
            stat_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    strategy=fluid.dygraph.prepare_context()
                    emb = fluid.dygraph.Embedding([10, 10])
                    emb = fluid.dygraph.DataParallel(emb, strategy)

                    state_dict = emb.state_dict()
                    fluid.save_dygraph(state_dict, "paddle_dy")

                    para_state_dict, _ = fluid.load_dygraph("paddle_dy")

                    emb.set_dict(para_state_dict)

        '''

        self._layers.set_dict(
            stat_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)

    def load_dict(self,
                  stat_dict,
                  include_sublayers=True,
                  use_structured_name=True):
        '''
        Set parameters of self._layers from stat_dict. All the parameters of self._layers will be reset by the tensor in the stat_dict

        This api will be Deprecated. Please use set_dict

        Parameters:
            stat_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    strategy=fluid.dygraph.prepare_context()
                    emb = fluid.dygraph.Embedding([10, 10])
                    emb = fluid.dygraph.DataParallel(emb, strategy)

                    state_dict = emb.state_dict()
                    fluid.save_dygraph(state_dict, "paddle_dy")

                    para_state_dict, _ = fluid.load_dygraph("paddle_dy")

                    emb.load_dict(para_state_dict)

        '''

        self._layers.load_dict(
            stat_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)