layers.py 27.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
Xin Pan 已提交
15
import collections
16 17 18
import contextlib
import sys
import numpy as np
M
minqiyang 已提交
19
import collections
20
import six
21
import re
C
chengduo 已提交
22
from . import parallel_helper
X
Xin Pan 已提交
23
from .. import unique_name
24
from paddle.fluid import core
25
from .layer_object_helper import LayerObjectHelper
26
from .base import program_desc_tracing_guard, param_guard
27
from paddle.fluid import framework
28
from ..param_attr import ParamAttr
H
hong 已提交
29
import copy
30
import weakref
H
hong 已提交
31
import warnings
32

33
__all__ = ['Layer']
34

35 36 37 38 39 40 41 42
_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re = re.compile('([a-z])([A-Z])')


def _convert_camel_to_snake(name):
    s1 = _first_cap_re.sub(r'\1_\2', name)
    return _all_cap_re.sub(r'\1_\2', s1).lower()

43

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
class HookRemoveHelper(object):
    """Handle returned when a hook is registered; ``remove()`` unregisters it.

    Only a weak reference to the hook container is kept, so an outstanding
    handle does not keep the container alive. Each helper gets a fresh id from
    a class-wide counter; that id is the key the hook is stored under.
    """

    # Class-wide counter; every new helper takes the current value as its id.
    next_hook_id = 0

    def __init__(self, hooks):
        # Take the weak reference first so a container that cannot be
        # weak-referenced fails before the shared counter is advanced.
        self._hooks_ref = weakref.ref(hooks)
        assigned_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id = assigned_id + 1
        self._hook_id = assigned_id

    def remove(self):
        """Delete this helper's hook from its container; no-op if already gone."""
        container = self._hooks_ref()
        if container is None:
            # The container has been garbage collected.
            return
        container.pop(self._hook_id, None)


X
Xin Pan 已提交
60
class Layer(core.Layer):
    """
    :alias_main: paddle.nn.Layer
    :alias: paddle.nn.Layer
    :old_api: paddle.fluid.dygraph.layers.Layer

    Dynamic graph Layer based on object-oriented design. It holds the parameters of the layer,
    the structure of the forward graph and so on.

    Parameters:
        name_scope (str, optional): prefix name used by the layer to name parameters.
            If prefix is "my_layer", parameter name in MyLayer
            can be "my_layer_0.w_n", where "w" is the parameter
            base name and "n" is an unique suffix auto-generated.
            If None, prefix name will be snake cased class name. Default: None.
        dtype(str or core.VarDesc.VarType, optional): data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                Default: ``core.VarDesc.VarType.FP32``

    Returns:
        None
    """

    def __init__(self, name_scope=None, dtype=core.VarDesc.VarType.FP32):
        self.training = True
        if name_scope is None:
            name_scope = _convert_camel_to_snake(self.__class__.__name__)
        self._full_name = unique_name.generate(name_scope)
        self._helper = LayerObjectHelper(self._full_name)
        self._built = False
        self._dtype = dtype

        self._parameters = collections.OrderedDict()
        self._sub_layers = collections.OrderedDict()
        # When non-empty, add_parameter/__setattr__ require every newly
        # attached Parameter's name to be a key here and copy the stored value
        # into it. NOTE(review): populated outside this file view — confirm.
        self._loaddict_holder = collections.OrderedDict()

        # Hook id (from HookRemoveHelper) -> hook callable, in registration order.
        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

    def train(self):
        """
        Sets this Layer and all its sublayers to training mode.
        This only effects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None
        """
        # global setting
        framework._dygraph_tracer().train_mode()
        # Layer-level setting. sublayers() already returns every descendant,
        # so the flag is set directly; calling layer.train() here would walk
        # each subtree again and re-enter the tracer once per descendant.
        self.training = True
        for layer in self.sublayers():
            layer.training = True

    def eval(self):
        """
        Sets this Layer and all its sublayers to evaluation mode.
        This only effects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None
        """
        # global setting
        framework._dygraph_tracer().eval_mode()
        # Layer-level setting; see train() for why descendants are not
        # recursed into individually.
        self.training = False
        for layer in self.sublayers():
            layer.training = False

    def full_name(self):
        """Full name for this layer, generated by ``unique_name.generate`` from the
        ``name_scope`` given to ``__init__`` (or, if None, from the snake-cased class name).

        Returns:
            str: full name of this layer.
        """
        return self._full_name

    def register_forward_post_hook(self, hook):
        """Register a forward post-hook for Layer. The hook will be called after `forward` function has been computed.

        It should have the following form, `input` and `output` of the `hook` is `input` and `output` of the `Layer` respectively.
        User can use forward post-hook to change the output of the Layer or perform information statistics tasks on the Layer.

        hook(Layer, input, output) -> None or modified output

        Parameters:
            hook(function): a function registered as a forward post-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              import numpy as np

              # the forward_post_hook change the output of the layer: output = output * 2
              def forward_post_hook(layer, input, output):
                  # user can use layer, input and output for information statistics tasks

                  # change the output
                  return output * 2

              with fluid.dygraph.guard():
                  linear = fluid.Linear(13, 5, dtype="float32")

                  # register the hook
                  forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

                  value1 = np.arange(26).reshape(2, 13).astype("float32")
                  in1 = fluid.dygraph.to_variable(value1)

                  out0 = linear(in1)

                  # remove the hook
                  forward_post_hook_handle.remove()

                  out1 = linear(in1)

                  # hook change the linear's output to output * 2, so out0 is equal to out1 * 2.
                  assert (out0.numpy() == (out1.numpy()) * 2).any()
        """
        hook_remove_helper = HookRemoveHelper(self._forward_post_hooks)
        self._forward_post_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def register_forward_pre_hook(self, hook):
        """Register a forward pre-hook for Layer. The hook will be called before `forward` function has been computed.

        It should have the following form, `input` of the `hook` is `input` of the `Layer`,
        hook can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if
        a single value is returned(unless that value is already a tuple).
        User can use forward pre-hook to change the input of the Layer or perform information statistics tasks on the Layer.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

              import paddle.fluid as fluid
              import numpy as np

              # the forward_pre_hook change the input of the layer: input = input * 2
              def forward_pre_hook(layer, input):
                  # user can use layer and input for information statistics tasks

                  # change the input
                  input_return = (input[0] * 2)
                  return input_return

              with fluid.dygraph.guard():
                  linear = fluid.Linear(13, 5, dtype="float32")

                  # register the hook
                  forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                  value0 = np.arange(26).reshape(2, 13).astype("float32")
                  in0 = fluid.dygraph.to_variable(value0)
                  out0 = linear(in0)

                  # remove the hook
                  forward_pre_hook_handle.remove()

                  value1 = value0 * 2
                  in1 = fluid.dygraph.to_variable(value1)
                  out1 = linear(in1)

                  # hook change the linear's input to input * 2, so out0 is equal to out1.
                  assert (out0.numpy() == out1.numpy()).any()
        """
        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def create_parameter(self,
                         shape,
                         attr=None,
                         dtype='float32',
                         is_bias=False,
                         default_initializer=None):
        """Create parameters for this layer.

        Parameters:
            shape(list): Shape of the parameter.
            attr(ParamAttr, optional): Parameter attribute of weight. Please refer to :ref:`api_fluid_ParamAttr`. Default: None.
            dtype(str or core.VarDesc.VarType, optional): Data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16". Default: "float32".
            is_bias(bool, optional): if this is a bias parameter. Default: False.
            default_initializer(Initializer, optional): the default initializer for this parameter.
                If set None, default initializer will be set to :ref:`api_fluid_initializer_XavierInitializer` and :ref:`api_fluid_initializer_ConstantInitializer`
                for non-bias and bias parameter, respectively. Default: None.

        Returns:
            :ref:`api_guide_Variable_en` : created parameter.
        """
        # Deep-copy so the helper cannot mutate the caller's attr object.
        temp_attr = copy.deepcopy(attr)
        # An empty-string attr is treated the same as no attr.
        if isinstance(temp_attr, six.string_types) and temp_attr == "":
            temp_attr = None
        return self._helper.create_parameter(temp_attr, shape, dtype, is_bias,
                                             default_initializer)

    # TODO: Add more parameter list when we need them
    def create_variable(self,
                        name=None,
                        persistable=None,
                        dtype=None,
                        type=core.VarDesc.VarType.LOD_TENSOR):
        """Create Variable for this layer.

        Parameters:
            name(str, optional): name of the variable. It is prefixed with this layer's full name.
                Please refer to :ref:`api_guide_Name` . Default: None
            persistable(bool, optional): if set this variable persistable. Default: None
            dtype(str or core.VarDesc.VarType, optional): data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                If set None, it will be ``core.VarDesc.VarType.FP32``. Default: None
            type(core.VarDesc.VarType, optional): type of the variable. No need to set this parameter. Default: ``core.VarDesc.VarType.LOD_TENSOR``

        Returns:
            :ref:`api_guide_Variable_en` : created Variable.
        """
        if name is not None:
            var_name = ".".join([self._full_name, name])
        else:
            var_name = unique_name.generate(".".join(
                [self._full_name, "_generated_var"]))

        return self._helper.main_program.current_block().create_var(
            name=var_name, persistable=persistable, dtype=dtype, type=type)

    def parameters(self, include_sublayers=True):
        """Returns a list of all Parameters from current layer and its sub-layers.

        Parameters:
            include_sublayers(bool, optional): Whether include the parameters of sublayers. If True, also include the parameters from sublayers. Default: True

        Returns:
            list of :ref:`api_guide_Variable_en` : a list of Parameters.
        """
        ret = [
            param
            for _, param in self.named_parameters(
                include_sublayers=include_sublayers)
        ]
        return ret

    def sublayers(self, include_sublayers=True):
        """Returns a list of sub layers.

        Parameters:
            include_sublayers(bool, optional): Whether return the sublayers of sublayers. If True, also include the sublayers of sublayers. Default: True

        Returns:
            list of Layer : a list of sub layers.
        """
        ret = [
            layer
            for _, layer in self.named_sublayers(
                include_sublayers=include_sublayers)
        ]
        return ret

    def named_parameters(self, prefix='', include_sublayers=True):
        """
        Returns an iterator over all parameters in the Layer, yielding tuple of name and parameter.
        A parameter shared by several layers is yielded only once.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_sublayers(bool, optional): Whether include the parameters of sublayers.
                If True, also include the named parameters from sublayers. Default: True.

        Yields:
            (string, Parameter): Tuple of name and Parameter

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                with fluid.dygraph.guard():
                    fc1 = fluid.Linear(10, 3)
                    fc2 = fluid.Linear(3, 10, bias_attr=False)
                    model = fluid.dygraph.Sequential(fc1, fc2)
                    for name, param in model.named_parameters():
                        print(name, param)

        """
        params_set = set()
        named_sublayers = self.named_sublayers(
            prefix=prefix,
            include_sublayers=include_sublayers,
            include_self=True)
        for layer_prefix, sublayer in named_sublayers:
            params = sublayer._parameters.items()
            for key, param in params:
                # Skip None placeholders and parameters already yielded.
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                name = layer_prefix + ('.' if layer_prefix else '') + key
                yield name, param

    def named_sublayers(self,
                        prefix='',
                        include_sublayers=True,
                        include_self=False,
                        layers_set=None):
        """
        Returns an iterator over all sublayers in the Layer, yielding tuple of name and sublayer.
        The duplicate sublayer will only be yielded once.

        Parameters:
            prefix(str, optional): Prefix to prepend to all sublayer names. Default: ''.
            include_sublayers(bool, optional): Whether include the sublayers. Default: True.
            include_self(bool, optional): Whether include the Layer itself. Default: False.
            layers_set(set, optional): The set to record duplicate sublayers. Default: None.

        Yields:
            (string, Layer): Tuple of name and Layer

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                with fluid.dygraph.guard():
                    fc1 = fluid.Linear(10, 3)
                    fc2 = fluid.Linear(3, 10, bias_attr=False)
                    model = fluid.dygraph.Sequential(fc1, fc2)
                    for prefix, layer in model.named_sublayers():
                        print(prefix, layer)

        """
        if layers_set is None:
            layers_set = set()
        if include_self and self not in layers_set:
            layers_set.add(self)
            yield prefix, self
        if include_sublayers:
            for key, layer in self._sub_layers.items():
                if layer is None:
                    continue
                layer_prefix = prefix + ('.' if prefix else '') + key
                # Sharing layers_set across the recursion is what suppresses
                # duplicates anywhere in the tree, not just among siblings.
                for p, l in layer.named_sublayers(
                        prefix=layer_prefix,
                        include_sublayers=include_sublayers,
                        include_self=True,
                        layers_set=layers_set):
                    yield p, l

    def clear_gradients(self):
        """
        Clear the gradients of all trainable parameters for this layer.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                with fluid.dygraph.guard():
                    value = np.arange(26).reshape(2, 13).astype("float32")
                    a = fluid.dygraph.to_variable(value)
                    linear = fluid.Linear(13, 5, dtype="float32")
                    adam = fluid.optimizer.Adam(learning_rate=0.01,
                                                parameter_list=linear.parameters())
                    out = linear(a)
                    out.backward()
                    adam.minimize(out)
                    linear.clear_gradients()

        """
        for p in self.parameters():
            if p.trainable:
                p.clear_gradient()

    def _build_once(self, *args, **kwargs):
        # Hook for subclasses; __call__ runs it once, before the first forward.
        pass

    def __call__(self, *inputs, **kwargs):
        # Pre-hooks may replace the positional inputs. A non-tuple result is
        # wrapped in a 1-tuple so it can still be unpacked into forward().
        for forward_pre_hook in self._forward_pre_hooks.values():
            hook_result = forward_pre_hook(self, inputs)
            if hook_result is not None:
                if not isinstance(hook_result, tuple):
                    hook_result = (hook_result, )
                inputs = hook_result

        if not self._built:
            # NOTE(review): program_desc_tracing_guard(False) presumably keeps
            # the one-time build out of program tracing — confirm.
            with program_desc_tracing_guard(False):
                self._build_once(*inputs, **kwargs)
                if parallel_helper._is_data_parallel_mode():
                    parallel_helper._broadcast_parameters(
                        self._parameters.values())
            self._built = True

        with param_guard(self._parameters):
            outputs = self.forward(*inputs, **kwargs)

        # Post-hooks may replace the output; None keeps the current output.
        for forward_post_hook in self._forward_post_hooks.values():
            hook_result = forward_post_hook(self, inputs, outputs)
            if hook_result is not None:
                outputs = hook_result

        return outputs

    def forward(self, *inputs, **kwargs):
        """
        Defines the computation performed at every call.
        Should be overridden by all subclasses.

        Parameters:
            *inputs(tuple): unpacked tuple arguments
            **kwargs(dict): unpacked dict arguments
        """
        raise NotImplementedError

    def backward(self, *inputs):
        raise ValueError("Layer shouldn't implement backward")

    def add_sublayer(self, name, sublayer):
        """Adds a sub Layer instance.

        Added sublayer can be accessed by self.name

        Parameters:
            name(str): name of this sublayer.
            sublayer(Layer): an instance of Layer.
        Returns:
            Layer: the sublayer passed in.
        """
        assert isinstance(sublayer, core.Layer)

        self._sub_layers[name] = sublayer
        return sublayer

    def add_parameter(self, name, parameter):
        """Adds a Parameter instance.

        Added parameter can be accessed by self.name

        Parameters:
            name(str): name of this parameter.
            parameter(Parameter|None): an instance of Parameter, or None to register
                the name without a parameter.
        Returns:
            Parameter: the parameter passed in.

        Raises:
            TypeError: if ``parameter`` is neither a Parameter nor None.
        """
        if parameter is None:
            # A None placeholder has no name/value, so there is nothing to
            # look up in _loaddict_holder; previously this path fell through
            # and crashed on `None.name` when the holder was non-empty.
            self._parameters[name] = None
            return parameter
        if not isinstance(parameter, framework.Parameter):
            raise TypeError(
                "parameter assignment requires Parameter or None, but got '{}'"
                .format(type(parameter).__name__))

        if len(self._loaddict_holder) > 0:
            assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
                parameter.name)

            parameter.set_value(self._loaddict_holder[parameter.name])

        self._parameters[name] = parameter
        return parameter

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. The registries are
        # read through __dict__: on an instance whose __init__ has not run
        # (e.g. one built by copy/pickle), `self._parameters` would itself
        # miss and re-enter __getattr__, recursing without end.
        parameters = self.__dict__.get('_parameters')
        if parameters is not None and name in parameters:
            return parameters[name]
        sub_layers = self.__dict__.get('_sub_layers')
        if sub_layers is not None and name in sub_layers:
            return sub_layers[name]
        return object.__getattribute__(self, name)

    def __setattr__(self, name, value):
        # Routes Parameters into _parameters and Layers into _sub_layers so
        # they are found by parameters()/sublayers(); anything else becomes a
        # plain instance attribute.
        def _remove_if_exist(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        # NOTE(review): this branch does not return, so the value is also
        # processed below (e.g. a property setter can run twice, or a
        # Parameter is additionally registered under the property name).
        # Left unchanged because that registration affects state_dict keys.
        if isinstance(getattr(type(self), name, None), property):
            object.__setattr__(self, name, value)
        params = self.__dict__.get('_parameters', None)
        if isinstance(value, framework.Parameter):
            if params is None:
                raise ValueError(
                    "super(YourLayer, self).__init__() should be called first")
            if len(self._loaddict_holder) > 0:
                assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
                    value.name)

                value.set_value(self._loaddict_holder[value.name])

            _remove_if_exist(self.__dict__, self._sub_layers)
            params[name] = value
        elif params is not None and name in params:
            # A registered parameter slot may only be reset to None.
            if value is not None:
                raise TypeError(
                    "assignment to parameter '{}' should be of type Parameter or None, but got '{}'"
                    .format(name, type(value).__name__))
            params[name] = None
        else:
            layers = self.__dict__.get('_sub_layers', None)
            if isinstance(value, core.Layer):
                if layers is None:
                    raise ValueError(
                        "super(YourLayer, self).__init__() should be called first"
                    )

                _remove_if_exist(self.__dict__, self._parameters)
                layers[name] = value
            elif layers is not None and name in layers:
                # A registered sublayer slot may only be reset to None.
                if value is not None:
                    raise TypeError(
                        "assignment to sublayer '{}' should be of type Layer or None, but got '{}'"
                        .format(name, type(value).__name__))
                layers[name] = None
            else:
                object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._sub_layers:
            del self._sub_layers[name]
        else:
            object.__delattr__(self, name)

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters of current layer and its sub-layers. And set all the parameters into a dict

        Parameters:
            destination(dict, optional) : If provide, all the parameters will set to this dict . Default: None
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key; sublayer keys are
                built as ``prefix + sublayer_name + "." + param_name``. Default: ""

        Returns:
            dict: a dict contains all the parameters

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])

                    state_dict = emb.state_dict()
                    fluid.save_dygraph( state_dict, "paddle_dy")

        '''

        if destination is None:
            destination = collections.OrderedDict()
        for name, data in self._parameters.items():
            if data is not None:
                destination[structured_name_prefix + name] = data

        if include_sublayers:
            for layer_name, layer_item in self._sub_layers.items():
                if layer_item is not None:
                    # Fill `destination` in place instead of copying it per
                    # sublayer (quadratic), so a caller-supplied dict really
                    # receives every parameter. A subclass override that
                    # returns a different dict is still merged in.
                    sub_state = layer_item.state_dict(
                        destination, include_sublayers,
                        structured_name_prefix + layer_name + ".")
                    if sub_state is not destination:
                        destination.update(sub_state)
        return destination

    def set_dict(self,
                 stat_dict,
                 include_sublayers=True,
                 use_structured_name=True):
        '''
        Set parameters from stat_dict. All the parameters will be reset by the tensor in the stat_dict

        Parameters:
            stat_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])

                    state_dict = emb.state_dict()
                    fluid.save_dygraph( state_dict, "paddle_dy")

                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")

                    emb.set_dict( para_state_dict )

        '''
        self.load_dict(
            stat_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)

    def load_dict(self,
                  stat_dict,
                  include_sublayers=True,
                  use_structured_name=True):
        '''
        Set parameters from stat_dict. All the parameters will be reset by the tensor in the stat_dict

        This api will be Deprecated. Please use set_dict

        Parameters:
            stat_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
                                                  Default: True
        Returns:
            None

        Raises:
            RuntimeError: if a parameter of this layer has no entry in ``stat_dict``.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                with fluid.dygraph.guard():
                    emb = fluid.dygraph.Embedding([10, 10])

                    state_dict = emb.state_dict()
                    fluid.save_dygraph( state_dict, "paddle_dy")

                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")

                    emb.load_dict( para_state_dict )

        '''

        inner_state_dict = self.state_dict(include_sublayers=include_sublayers)

        # Track the keys actually consumed: with use_structured_name=False the
        # stat_dict keys are parameter names, not the structured names used
        # as inner_state_dict keys.
        used_keys = set()
        for name, para in inner_state_dict.items():
            key_name = name if use_structured_name else para.name
            if key_name in stat_dict:
                para.set_value(stat_dict[key_name])
                used_keys.add(key_name)
            else:
                raise RuntimeError(
                    "Parameter not found, Can't not find [ {} ] in stat_dict, "
                    "use_structured_name is set to [{}]".format(
                        key_name, use_structured_name))
        unused_para_list = [k for k in stat_dict if k not in used_keys]
        if unused_para_list:
            warnings.warn(
                "Varibale [ {} ] are not used, because not included in layers state_dict".
                format(" ".join(unused_para_list)))