parallel.py 25.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16
import six
Y
Yan Xu 已提交
17
import numpy as np
18
import warnings
19
from collections import OrderedDict
20 21 22 23 24 25 26

from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
27
from ..layers import collective
28
from paddle.fluid.dygraph import base as imperative_base
29
import warnings
30
import paddle
31
import itertools
32

33
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]

# Module-level alias of ``core.ParallelStrategy``; instantiated by
# ``prepare_context`` and ``_build_default_parallel_strategy`` in this file.
ParallelStrategy = core.ParallelStrategy


38
@deprecated(since="2.0.0", update_to="paddle.distributed.init_parallel_env")
def prepare_context(strategy=None):
    '''
    :api_attr: imperative

    Deprecated: use ``paddle.distributed.init_parallel_env`` instead.

    Create and initialize the global parallel context used by dygraph
    data-parallel training.

    Args:
        strategy(ParallelStrategy, optional): Parallel settings (``nranks``,
            ``local_rank``, ``trainer_endpoints``, ``current_endpoint``).
            When None, they are read from the environment through ``Env()``.
            Default: None.

    Returns:
        ParallelStrategy|None: The strategy that was used, or None when
        ``strategy.nranks < 2`` (nothing is initialized in that case).

    Raises:
        AssertionError: If not in dygraph mode, if there is no expected
            place, or if the current place is neither ``CUDAPlace`` nor
            ``XPUPlace``.
    '''
    if strategy is None:
        # Read the environment once instead of re-parsing it per field.
        env = Env()
        strategy = ParallelStrategy()
        strategy.nranks = env.nranks
        strategy.local_rank = env.local_rank
        strategy.trainer_endpoints = env.trainer_endpoints
        strategy.current_endpoint = env.current_endpoint
    if strategy.nranks < 2:
        return

    assert framework.in_dygraph_mode() is True, \
        "dygraph.prepare_context should be used with dygraph mode."
    place = framework._current_expected_place()
    assert place is not None, \
        "dygraph.prepare_context should be used in fluid.dygraph.guard(place) guard."
    if not parallel_helper._is_parallel_ctx_initialized():
        if isinstance(place, core.CUDAPlace):
            parallel_helper._set_parallel_ctx(
                core.NCCLParallelContext(strategy, place))
        elif isinstance(place, core.XPUPlace):
            parallel_helper._set_parallel_ctx(
                core.BKCLParallelContext(strategy, place))
        else:
            # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
            # An `assert "<non-empty string>"` is always true, so unsupported
            # places used to fall through silently. Raise explicitly; the
            # AssertionError type matches the other checks in this function.
            raise AssertionError("Only support CUDAPlace or XPUPlace for now.")
        parallel_helper._init_parallel_ctx()
    return strategy
68 69


70 71
class ParallelEnv(object):
    """
    .. note::
        This API is not recommended. To get the rank and world size,
        use ``paddle.distributed.get_rank()`` and
        ``paddle.distributed.get_world_size()`` instead.

    This class is used to obtain the environment variables required for
    the parallel execution of ``paddle.nn.Layer`` in dynamic mode.
    All values are read from the environment once, when the object is
    constructed.

    The parallel execution in dynamic mode needs to be started using
    ``paddle.distributed.launch`` or ``paddle.distributed.spawn`` .

    Examples:
      .. code-block:: python

        import paddle
        import paddle.distributed as dist

        def train():
            # 1. initialize parallel environment
            dist.init_parallel_env()

            # 2. get current ParallelEnv
            parallel_env = dist.ParallelEnv()
            print("rank: ", parallel_env.rank)
            print("world_size: ", parallel_env.world_size)

            # print result in process 1:
            # rank: 0
            # world_size: 2
            # print result in process 2:
            # rank: 1
            # world_size: 2

        if __name__ == '__main__':
            # 1. start by ``paddle.distributed.spawn`` (default)
            dist.spawn(train, nprocs=2)
            # 2. start by ``paddle.distributed.launch``
            # train()
    """

    def __init__(self):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

        # Imperative mode only supports one GPU or XPU per process, so only
        # the first id of the comma-separated selection is used. Default to 0
        # so that ``device_id`` does not raise AttributeError on builds
        # compiled with neither CUDA nor XPU.
        self._device_id = 0
        if core.is_compiled_with_cuda():
            selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
            self._device_id = int(selected_gpus[0])
        elif core.is_compiled_with_xpu():
            selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
            self._device_id = int(selected_xpus[0])

        # NOTE: when the variable is unset this yields [''] rather than [],
        # because "".split(",") == [''].
        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
                                            "").split(",")
        self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
        self._nrings = int(os.getenv("FLAGS_nccl_nrings", "1"))
        assert self._nrings > 0, \
            "nccl_nrings must be an integer greater than 0."
        assert self._nrings < 9, \
            "nccl_nrings should be less than 9, which is enough in most scenarios."

    @property
    def rank(self):
        """
        Rank of current trainer.

        Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ID=0
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The rank is %d" % env.rank)
            # The rank is 0
        """
        return self._rank

    @property
    def world_size(self):
        """
        The number of trainers (number of processes participating in current job).

        Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINERS_NUM=4
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The world_size is %d" % env.world_size)
            # The world_size is 4
        """
        return self._world_size

    @property
    def device_id(self):
        """
        The ID of the selected GPU or XPU card for parallel training.

        Its value is the first id in the environment variable
        ``FLAGS_selected_gpus`` (CUDA builds) or ``FLAGS_selected_xpus``
        (XPU builds). The default value is 0.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_selected_gpus=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The device id is %d" % env.device_id)
            # The device id is 1
        """
        return self._device_id

    @property
    def current_endpoint(self):
        """
        The endpoint of current trainer, it is in the form of (node IP + port).

        Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "".

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The current endpoint is %s" % env.current_endpoint)
            # The current endpoint is 127.0.0.1:6170
        """
        return self._current_endpoint

    @property
    def trainer_endpoints(self):
        """
        The endpoints of all trainer nodes in the task,
        which are used to broadcast the NCCL ID when NCCL2 is initialized.

        Its value is the environment variable ``PADDLE_TRAINER_ENDPOINTS``
        split on ",". When the variable is unset, the value is ``['']``.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The trainer endpoints are %s" % env.trainer_endpoints)
            # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171']
        """
        return self._trainer_endpoints

    @property
    def nrings(self):
        """
        Nrings of current trainer.

        Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1.

        Examples:
          .. code-block:: python

            # execute this command in terminal: export FLAGS_nccl_nrings=1
            import paddle.distributed as dist

            env = dist.ParallelEnv()
            print("The nrings is %d" % env.nrings)
            # The nrings is 1
        """
        return self._nrings

    # [aliases] Compatible with old method names
    local_rank = rank
    nranks = world_size
    dev_id = device_id

253

254 255 256 257 258 259
# NOTE: [ Compatible ] This class was originally named `Env`. That name is
# imprecise and may confuse users, so the class was renamed to `ParallelEnv`.
# The old name is kept as an alias so that existing code and examples that
# use `Env` keep working.
Env = ParallelEnv


260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
def _build_default_parallel_strategy():
    """Build a ``ParallelStrategy`` populated from the current ``ParallelEnv``.

    Copies ``nranks``, ``local_rank``, ``trainer_endpoints`` and
    ``current_endpoint`` from the environment values read by
    ``ParallelEnv``.

    Returns:
        ParallelStrategy: the populated strategy.
    """
    # Construct ParallelEnv once instead of re-reading the environment per field.
    env = ParallelEnv()
    strategy = ParallelStrategy()
    strategy.nranks = env.nranks
    strategy.local_rank = env.local_rank
    strategy.trainer_endpoints = env.trainer_endpoints
    strategy.current_endpoint = env.current_endpoint
    return strategy


def _coalesce_tensors(var_groups):
    """Flatten and concatenate each group of variables into one 1-D tensor.

    Args:
        var_groups(dict): mapping from group id to a list of variables. Only
            the values are used, in the dict's iteration order.

    Returns:
        list: one ``[coalesced_grad, grad_vars, g_var_shapes]`` entry per
        group. ``coalesced_grad`` is the concatenation of the group's
        variables reshaped to 1-D, and ``g_var_shapes`` records their
        original shapes. This is the format unpacked by ``_split_tensors``
        and ``sync_params_buffers``.
    """
    from ..layers import nn
    coalesced_grads_and_grad_vars = []
    for grad_vars in var_groups.values():
        g_var_shapes = [g_var.shape for g_var in grad_vars]
        flattened_vars = [
            nn.reshape(
                x=g_var, shape=[np.prod(g_var_shape)])
            for g_var, g_var_shape in zip(grad_vars, g_var_shapes)
        ]
        coalesced_grad = nn.concat(flattened_vars)
        coalesced_grads_and_grad_vars.append(
            [coalesced_grad, grad_vars, g_var_shapes])
    return coalesced_grads_and_grad_vars


@framework.dygraph_only
def _reshape_inplace(x, shape):
    """Reshape ``x`` to ``shape`` by tracing a ``reshape2`` op whose ``Out`` is ``x`` itself.

    A fresh variable with ``x``'s dtype is created to receive the
    ``XShape`` output that ``reshape2`` requires.
    """
    xshape_holder = framework._varbase_creator(dtype=x.dtype)
    op_outputs = {'Out': x, 'XShape': xshape_holder}
    tracer = framework._dygraph_tracer()
    tracer.trace_op(
        type="reshape2",
        inputs={'X': x},
        outputs=op_outputs,
        attrs={'shape': shape})


@framework.dygraph_only
def _split_tensors(coalesced_grads_and_grad_vars):
    """Scatter each coalesced 1-D tensor back into its original variables.

    Args:
        coalesced_grads_and_grad_vars(list): entries of
            ``[coalesced_grad, origin_grad_vars, grad_shapes]``, as built by
            ``_coalesce_tensors``.

    For each entry, a ``split`` op along axis 0 writes consecutive slices of
    ``coalesced_grad`` into ``origin_grad_vars``. The slice sizes are the
    element counts of ``grad_shapes``. Each variable is then reshaped in
    place back to its recorded shape.
    """
    for entry in coalesced_grads_and_grad_vars:
        coalesced_grad, origin_grad_vars, grad_shapes = entry
        sections = [np.prod(shape) for shape in grad_shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_grad},
            outputs={'Out': origin_grad_vars},
            attrs={'sections': sections,
                   'axis': 0})
        for var, shape in zip(origin_grad_vars, grad_shapes):
            _reshape_inplace(x=var, shape=shape)
            # Post-condition: the in-place reshape restored the recorded shape.
            assert var.shape == shape


def scale_loss(loss):
    """Divide ``loss`` by the data-parallel world size.

    Returns ``loss`` unchanged when ``ParallelEnv().world_size`` is 1 or
    less. Otherwise returns ``loss / world_size``, where the divisor is a
    float32 tensor created with ``stop_gradient=True``.
    """
    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
    # Read the environment once rather than constructing ParallelEnv twice.
    world_size = ParallelEnv().world_size
    if world_size <= 1:
        return loss

    loss_scale = to_variable(np.array([world_size]).astype("float32"))
    loss_scale.stop_gradient = True
    scaled_loss = loss / loss_scale
    return scaled_loss


324 325
@imperative_base.no_grad
@framework.dygraph_only
def build_groups(vars, group_size):
    """Partition ``vars`` into consecutive groups and coalesce each group.

    Variables are assigned in order. A variable joins the current group
    while the group's accumulated byte count is still below ``group_size``
    and its dtype matches the group's dtype. Otherwise it starts a new group.
    The size check happens before adding, so a group can overshoot
    ``group_size`` by at most the size of its last variable.

    Args:
        vars(list): variables to group. Only ``shape`` and ``dtype`` are read
            here.
        group_size(int): soft per-group limit in bytes.

    Returns:
        list: the ``_coalesce_tensors`` result for the groups, or an empty
        list when ``vars`` is empty.
    """
    if not vars:
        # Nothing to group; also avoids an IndexError on vars[0] below.
        return []

    group_idx = 0
    memory_counter = 0
    var_groups = OrderedDict()
    dtype = vars[0].dtype

    for var in vars:
        # Named var_bytes to avoid shadowing the builtin ``bytes``.
        var_bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
        if memory_counter < group_size and dtype == var.dtype:
            memory_counter += var_bytes
        else:
            memory_counter = var_bytes
            dtype = var.dtype
            group_idx += 1
        var_groups.setdefault(group_idx, []).append(var)
    return _coalesce_tensors(var_groups)


@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(model,
                        comm_group=None,
                        src_rank=0,
                        is_model_parallel=False):
    """Broadcast the entries of ``model.state_dict()`` from ``src_rank``.

    The entries are processed in four steps:

    1. Each entry is detached.
    2. The detached entries are coalesced into groups of about 128MB
       (``build_groups``).
    3. Each group is broadcast with ``paddle.distributed.broadcast`` using
       ``use_calc_stream=True``.
    4. Each group is split back into the detached variables.

    Args:
        model(Layer): the layer whose ``state_dict()`` entries are synced.
        comm_group(optional): forwarded as ``group`` to ``broadcast``.
            Default: None.
        src_rank(int, optional): rank to broadcast from. Default: 0.
        is_model_parallel(bool, optional): when True, entries whose
            ``is_distributed`` attribute is truthy are skipped. Default: False.

    Raises:
        TypeError: if a ``state_dict()`` value is not a ``core.VarBase``.
    """
    model_vars = []
    for key, param in model.state_dict().items():
        if not isinstance(param, core.VarBase):
            # A non-VarBase value may have no ``name``. Fall back to the
            # state_dict key so the intended TypeError is not masked by an
            # AttributeError.
            raise TypeError("The data type of '%s' must be Varbase" %
                            getattr(param, "name", key))
        # is_distributed param not need to sync when in mp mode
        if is_model_parallel and param.is_distributed:
            continue

        # NOTE(review): the broadcast result reaches the model only if
        # detach() shares storage with ``param``; presumably true, confirm.
        model_vars.append(param.detach())
    if not model_vars:
        return

    # group size is 128M
    coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)

    for coalesced_var, _, _ in coalesced_vars:
        paddle.distributed.broadcast(
            coalesced_var, src=src_rank, group=comm_group, use_calc_stream=True)

    for coalesced_var, origin_vars, var_shapes in coalesced_vars:
        var_len = [np.prod(v_shape) for v_shape in var_shapes]
        framework._dygraph_tracer().trace_op(
            type='split',
            inputs={'X': coalesced_var},
            outputs={'Out': origin_vars},
            attrs={'sections': var_len,
                   'axis': 0})


380
class DataParallel(layers.Layer):
    """
    Run the dygraph module with data parallelism.

    Currently, the DataParallel class only supports running the dynamic
    graph with multiple processes.

    Now supports two ways to start training:

    1. start by ``paddle.distributed.spawn`` method, for example:

        ``python demo.py`` (spawn need to be called in ``__main__`` method)

    2. start by ``paddle.distributed.launch`` module, for example:

        ``python -m paddle.distributed.launch --gpus=0,1 demo.py`` .

    And the content of `demo.py` is the code of examples.

    Args:
        layers(Layer): The module that should be executed by data parallel.
        strategy(ParallelStrategy, optional): (deprecated) The strategy of data parallelism,
            contains environment configuration related to parallel execution. Default: None.
        comm_buffer_size(int, optional): The memory size (MB) limit of one buffer of
            parameters' gradients, which is the input of a communication call
            (e.g. NCCLAllReduce). Default: 25.
        last_comm_buffer_size(float, optional): The memory size (MB) limit of the last
            buffer in communication calls. Making the last communication buffer small
            is useful to improve performance. Default: 1.
        find_unused_parameters(bool, optional): Whether to traverse the entire backward graph
            starting from all tensors in the return value of the wrapped model's forward
            function. For parameters not involved in the loss calculation, their gradients
            will be marked as ready in advance to prepare the reduce. Please note that all
            forward outputs derived from the wrapped model parameters must participate in
            the calculation of the loss and subsequent gradient calculations; otherwise a
            serious error will occur. Setting find_unused_parameters to True affects
            computing performance, so if all parameters are sure to participate in the
            loss calculation and the autograd graph construction, please set it to False.
            Default: True.

    Returns:
        Layer: The data paralleled module.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import paddle.optimizer as opt
            import paddle.distributed as dist

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            def train():
                # 1. initialize parallel environment
                dist.init_parallel_env()

                # 2. create data parallel layer & optimizer
                layer = LinearNet()
                dp_layer = paddle.DataParallel(layer)

                loss_fn = nn.MSELoss()
                adam = opt.Adam(
                    learning_rate=0.001, parameters=dp_layer.parameters())

                # 3. run layer
                inputs = paddle.randn([10, 10], 'float32')
                outputs = dp_layer(inputs)
                labels = paddle.randn([10, 1], 'float32')
                loss = loss_fn(outputs, labels)

                loss.backward()

                adam.step()
                adam.clear_grad()

            if __name__ == '__main__':
                # 1. start by ``paddle.distributed.spawn`` (default)
                dist.spawn(train, nprocs=2)
                # 2. start by ``paddle.distributed.launch``
                # train()
    """

    def __init__(self,
                 layers,
                 strategy=None,
                 comm_buffer_size=25,
                 last_comm_buffer_size=1,
                 find_unused_parameters=True):
        super(DataParallel,
              self).__init__(layers.full_name() + "_data_parallel")

        self._layers = layers
        self.find_unused_parameters = find_unused_parameters

        # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
        # It just stores some environment variables, which can be constructed by
        # ParallelEnv. Here it is set as an optional argument.
        # This parameter is not removed because of compatibility with 1.x writing.
        if strategy is not None:
            self._strategy = strategy
        else:
            self._strategy = _build_default_parallel_strategy()

        if self._strategy.nranks > 1:
            # check the environment
            assert parallel_helper.__parallel_ctx__clz__ is not None, \
            "ParallelContext must be initialized before. You should use init_parallel_env() before " \
            "constructing the DataParallel."

            # sync buffer and params
            # TODO(liuyuhui) Currently not support xpu. xpu is
            # still broadcasting parameters when calling layer
            if not paddle.is_compiled_with_xpu():
                sync_params_buffers(self._layers)

            # Buffer sizes are given in MB; convert them to bytes.
            self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
            # NOTE(shenliang03): ``last_comm_buffer_size`` controls the size of
            # the last group (default: 1MB). When the last group is
            # all-reduced, the overlap cannot work, so keeping the last group
            # small is useful to improve performance.
            self.last_comm_buffer_size = int(last_comm_buffer_size * 1024 *
                                             1024)
            self.init_reducer()
        else:
            warnings.warn("The program will return to single-card operation. "
                          "Please check 1, whether you use spawn or fleetrun "
                          "to start the program. 2, Whether it is a multi-card "
                          "program. 3, Is the current environment multi-card.")

    def init_reducer(self):
        """Collect trainable parameters and build the gradient ``core.Reducer``.

        Parameters are gathered once each (deduplicated) from every sublayer.
        They are grouped with ``core.assign_group_by_size`` using
        ``[last_comm_buffer_size, comm_buffer_size]``, and the reversed group
        order is passed to the reducer.
        """
        layers_param = []
        params_set = set()
        for sublayer in self.sublayers():
            for name, param in sublayer.named_parameters(
                    include_sublayers=False):
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                if not isinstance(param, core.VarBase):
                    # Fall back to the parameter's attribute name when the
                    # value has no ``name`` so the TypeError is not masked.
                    raise TypeError("The data type of '%s' must be Varbase" %
                                    getattr(param, "name", name))
                if param.trainable:
                    layers_param.append((sublayer, param))

        trainable_parameters = [param for _, param in layers_param]

        assert len(trainable_parameters) > 0, \
            "This model does not have any parameters to train, and " \
            "does not need to use DataParallel"

        # NOTE(shenliang03): Here we can only use the attributes to judge whether
        # parameter is sparse(or SelectedRows). The reason is that the sparse message
        # can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
        # we should add the layer here like "paddle.nn.layer.common.Embedding".
        def check_layer_sparse(sublayer):
            if isinstance(sublayer, paddle.nn.layer.common.Embedding):
                return sublayer._sparse
            # NOTE(shenliang03):This is for compatibility. If paddle.fluid.dygraph.Embedding
            # is removed in the future, the check will also be removed here.
            if isinstance(sublayer, paddle.fluid.dygraph.Embedding):
                return sublayer._is_sparse
            return False

        is_sparse_gradient = [
            check_layer_sparse(sublayer) for sublayer, _ in layers_param
        ]

        self.group_indices = core.assign_group_by_size(
            trainable_parameters, is_sparse_gradient,
            [self.last_comm_buffer_size, self.comm_buffer_size])

        self._reducer = core.Reducer(
            trainable_parameters,
            list(reversed(self.group_indices)), is_sparse_gradient,
            parallel_helper.__parallel_ctx__clz__,
            [self.last_comm_buffer_size, self.comm_buffer_size],
            self.find_unused_parameters)

    def _find_varbase(self, obj):
        """Recursively collect ``core.VarBase`` objects from ``obj``.

        Descends into lists, tuples and dict values; any other object
        contributes nothing. Returns an iterable (a list or an
        ``itertools.chain``).
        """
        if isinstance(obj, core.VarBase):
            return [obj]
        if isinstance(obj, (list, tuple)):
            return itertools.chain(*map(self._find_varbase, obj))
        if isinstance(obj, dict):
            return itertools.chain(*map(self._find_varbase, obj.values()))
        return []

    def forward(self, *inputs, **kwargs):
        outputs = self._layers(*inputs, **kwargs)
        if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
            if self.find_unused_parameters:
                # Hand every VarBase in the (possibly nested) outputs to the
                # reducer as the roots of the backward-graph traversal.
                self._reducer.prepare_for_backward(
                    list(self._find_varbase(outputs)))
            else:
                # No traversal roots are needed in this mode.
                self._reducer.prepare_for_backward([])

        return outputs

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def scale_loss(self, loss):
        """
        Deprecated method, now ``scale_loss`` is an empty method,
        keep this method just for compatibility.
        """
        return loss

    @deprecated(
        since="2.0.0", reason="This method does not need to be called anymore.")
    def apply_collective_grads(self):
        """
        Deprecated method, now ``apply_collective_grads`` is an empty method,
        keep this method just for compatibility.
        """
        return

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Forwarded to the wrapped layer's ``state_dict``. Default: ""

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''

        return self._layers.state_dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix)

    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.distributed as dist

                dist.init_parallel_env()

                emb = paddle.nn.Embedding(10, 10)
                emb = paddle.DataParallel(emb)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict, use_structured_name=use_structured_name)

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict