# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import sys
import numpy as np
import re
import copy
import weakref
import warnings
from copy import deepcopy
import inspect

import paddle
import paddle.profiler as profiler
from paddle.profiler.utils import in_profiler_mode

from . import parallel_helper
from .. import unique_name
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
from .layer_hooks import (
    record_program_ops_pre_hook,
    set_op_customized_attrs_post_hook,
    LayerOpsRecoder,
)
from .base import (
    program_desc_tracing_guard,
    param_guard,
    in_declarative_mode,
    _convert_into_variable,
)
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import (
    convert_np_dtype_to_dtype_,
    in_dygraph_mode,
)
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad
import paddle.utils.deprecated as deprecated

__all__ = ['Layer']

_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re = re.compile('([a-z])([A-Z])')


def _scope_dist2single(dist_scope):
    mapping = {
        "row_parallel_linear": "linear",
        "column_parallel_linear": "linear",
        "vocab_parallel_embedding": "embedding",
        # "parallel_cross_entropy": "cross_entropy", while mp_layer has parallel_cross_entropy,
S
Shuangchi He 已提交
70
        # but there is no parameters so the mapping of parallel_cross_entropy is not necessary.
    }
    return mapping.get(dist_scope, dist_scope)
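
# Editor's note (illustrative sketch, not part of the original module): the helper above
# collapses a model-parallel layer's name scope to its single-card counterpart, e.g.
#   _scope_dist2single("row_parallel_linear")        # -> "linear"
#   _scope_dist2single("vocab_parallel_embedding")   # -> "embedding"
#   _scope_dist2single("linear")                     # unknown keys pass through -> "linear"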


def _convert_camel_to_snake(name):
    s1 = _first_cap_re.sub(r'\1_\2', name)
    return _all_cap_re.sub(r'\1_\2', s1).lower()
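
# Editor's note (illustrative sketch, not part of the original module): this helper derives
# the default name scope from a class name, e.g.
#   _convert_camel_to_snake("MyLayer")   # -> "my_layer"
#   _convert_camel_to_snake("Linear")    # -> "linear"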


def _addindent(string, indent):
    s1 = string.split('\n')
    if len(s1) == 1:
        return string
    s2 = []
    for idx, line in enumerate(s1):
        if idx > 0:
            s2.append(str((indent * ' ') + line))
    return s1[0] + '\n' + '\n'.join(s2)
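
# Editor's note (illustrative sketch, not part of the original module): _addindent keeps the
# first line unindented and indents every following line; __repr__ below relies on it when
# nesting sublayer representations, e.g.
#   _addindent("Linear\n(bias)", 2)   # -> "Linear\n  (bias)"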


class HookRemoveHelper:
    """A helper class that can be used to remove a registered hook."""

    next_hook_id = 0

    def __init__(self, hooks):
        self._hooks_ref = weakref.ref(hooks)
        self._hook_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id += 1

    def remove(self):
        hooks = self._hooks_ref()
        if hooks is not None and self._hook_id in hooks:
            del hooks[self._hook_id]
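
# Editor's sketch (hypothetical usage, not part of the original module): layer code stores
# hooks in an OrderedDict keyed by the helper's id; the weak reference means a kept helper
# does not keep the hook container (and thus the Layer) alive, e.g.
#   hooks = collections.OrderedDict()
#   helper = HookRemoveHelper(hooks)
#   hooks[helper._hook_id] = lambda *args: None
#   helper.remove()   # removes the hook; safe to call even after `hooks` is garbage collected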


class Layer:
    """
    Dynamic graph Layer based on OOD, which includes the parameters of the layer, the structure of the forward graph and so on.

    Parameters:
        name_scope (str, optional): prefix name used by the layer to name parameters.
            If prefix is "my_layer", parameter name in MyLayer
            can be "my_layer_0.w_n", where "w" is the parameter
            base name and "n" is a unique suffix auto-generated.
            If None, prefix name will be the snake-cased class name. Default: None.
        dtype(str, optional): data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                Default: "float32"

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            class MyLayer(paddle.nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = paddle.nn.Linear(1, 1)
                    self._dropout = paddle.nn.Dropout(p=0.5)
                def forward(self, input):
                    temp = self._linear(input)
                    temp = self._dropout(temp)
                    return temp
            x = paddle.randn([10, 1], 'float32')
            mylayer = MyLayer()
            mylayer.eval()  # set mylayer._dropout to eval mode
            out = mylayer(x)
            mylayer.train()  # set mylayer._dropout to train mode
            out = mylayer(x)
    """

    def __init__(self, name_scope=None, dtype="float32"):
        self.training = True
        if name_scope is None:
            name_scope = _convert_camel_to_snake(self.__class__.__name__)
            name_scope = _scope_dist2single(name_scope)
        self._full_name = unique_name.generate(name_scope)
        self._helper = LayerObjectHelper(self._full_name)
        self._built = False
        self._dtype = dtype
        self._init_in_dynamic_mode = in_dygraph_mode()

        self._parameters = collections.OrderedDict()
        # Buffers the variables (not parameters) created in the layer
        self._buffers = collections.OrderedDict()
        self._non_persistable_buffer_names_set = set()
        self._sub_layers = collections.OrderedDict()
        self._loaddict_holder = collections.OrderedDict()

        # Record generated op_descs in this layer
        self._op_recorder = LayerOpsRecoder(ops=[], hooks=[])
        self._customized_attrs = {}

        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

        self._casted_by_pure_fp16 = False

        self._state_dict_hooks = collections.OrderedDict()
        # Records original functions after @to_static to support rollback
        self._original_funcs = collections.OrderedDict()

    def train(self):
        """

        Sets this Layer and all its sublayers to training mode.
        This only affects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                mylayer.train()  # set mylayer._dropout to train mode
                out = mylayer(x)

        """
        # global setting in dygraph
        # NOTE(chenweihang): nn.Layer also can be used in static mode,
        # but _dygraph_tracer() can not be called in static mode
        if in_dygraph_mode():
            framework._dygraph_tracer().train_mode()
        # Layer-level setting
        self.training = True
        for layer in self.sublayers():
            layer.training = True

    def eval(self):
        """
        Sets this Layer and all its sublayers to evaluation mode.
        This only affects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None

        Example::
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                print(out)

        """
        # global setting in dygraph
        # NOTE(chenweihang): nn.Layer also can be used in static mode,
        # but _dygraph_tracer() can not be called in static mode
        if in_dygraph_mode():
            framework._dygraph_tracer().eval_mode()
        # Layer-level setting
        self.training = False
        for layer in self.sublayers():
            layer.training = False

    def apply(self, fn):
        """

        Applies ``fn`` recursively to every sublayer (as returned by ``.sublayers()``)
        as well as self. Typical use includes initializing the parameters of a model.

        Parameters:
            fn (function): a function to be applied to each sublayer

        Returns:
            Layer, self

        Example::
            .. code-block:: python

              import paddle
              import paddle.nn as nn

              net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

              def init_weights(layer):
                  if type(layer) == nn.Linear:
                      print('before init weight:', layer.weight.numpy())
                      new_weight = paddle.full(shape=layer.weight.shape, dtype=layer.weight.dtype, fill_value=0.9)
                      layer.weight.set_value(new_weight)
                      print('after init weight:', layer.weight.numpy())

              net.apply(init_weights)

              print(net.state_dict())

        """
        for layer in self.children():
            layer.apply(fn)

        fn(self)

        return self

    def full_name(self):
        """

        Full name for this layer, composed of name_scope + "/" + MyLayer.__class__.__name__

        Returns:
            str, full name of this layer.

        Example::
            .. code-block:: python

                import paddle

                class LinearNet(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__(name_scope = "demo_linear_net")
                        self._linear = paddle.nn.Linear(1, 1)

                    def forward(self, x):
                        return self._linear(x)

                linear_net = LinearNet()
                print(linear_net.full_name())   # demo_linear_net_0

        """
        return self._full_name

    def register_forward_post_hook(self, hook):
        """

        Register a forward post-hook for Layer. The hook will be called after the `forward` function has been computed.

        It should have the following form: `input` and `output` of the `hook` are `input` and `output` of the `Layer`, respectively.
        User can use forward post-hook to change the output of the Layer or perform information statistics tasks on the Layer.

        hook(Layer, input, output) -> None or modified output

        Parameters:
            hook(function): a function registered as a forward post-hook

        Returns:
            HookRemoveHelper, a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_post_hook changes the output of the layer: output = output * 2
                def forward_post_hook(layer, input, output):
                    # user can use layer, input and output for information statistics tasks

                    # change the output
                    return output * 2

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

                value1 = np.arange(26).reshape(2, 13).astype("float32")
                in1 = paddle.to_tensor(value1)

                out0 = linear(in1)

                # remove the hook
                forward_post_hook_handle.remove()

                out1 = linear(in1)

                # the hook changes the linear's output to output * 2, so out0 is equal to out1 * 2.
                assert (out0.numpy() == (out1.numpy()) * 2).all()

        """
        hook_remove_helper = HookRemoveHelper(self._forward_post_hooks)
        self._forward_post_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def register_forward_pre_hook(self, hook):
        """

        Register a forward pre-hook for Layer. The hook will be called before the `forward` function is computed.

        It should have the following form: `input` of the `hook` is `input` of the `Layer`.
        The hook can return either a tuple or a single modified value; we will wrap a single
        returned value into a tuple (unless that value is already a tuple).
        User can use forward pre-hook to change the input of the Layer or perform information statistics tasks on the Layer.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper, a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_pre_hook changes the input of the layer: input = input * 2
                def forward_pre_hook(layer, input):
                    # user can use layer and input for information statistics tasks

                    # change the input
                    input_return = (input[0] * 2)
                    return input_return

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                value0 = np.arange(26).reshape(2, 13).astype("float32")
                in0 = paddle.to_tensor(value0)
                out0 = linear(in0)

                # remove the hook
                forward_pre_hook_handle.remove()

                value1 = value0 * 2
                in1 = paddle.to_tensor(value1)
                out1 = linear(in1)

                # the hook changes the linear's input to input * 2, so out0 is equal to out1.
                assert (out0.numpy() == out1.numpy()).all()
        """
        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def create_parameter(
        self,
        shape,
        attr=None,
        dtype=None,
        is_bias=False,
        default_initializer=None,
    ):
        """Create parameters for this layer.

        Parameters:
            shape(list): Shape of the parameter.
            attr(ParamAttr, optional): Parameter attribute of weight. Please refer to :ref:`api_paddle_ParamAttr`. Default: None.
            dtype(str, optional): Data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16". Default: "float32".
            is_bias(bool, optional): if this is a bias parameter. Default: False.
            default_initializer(Initializer, optional): the default initializer for this parameter.
                If set None, default initializer will be set to paddle.nn.initializer.Xavier and paddle.nn.initializer.Constant
                for non-bias and bias parameter, respectively. Default: None.

        Returns:
            Tensor, created parameter.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        temp_attr = copy.deepcopy(attr)
        if isinstance(temp_attr, str) and temp_attr == "":
            temp_attr = None
        return self._helper.create_parameter(
            temp_attr, shape, dtype, is_bias, default_initializer
        )

    @deprecated(
        since="2.0.0",
        update_to="paddle.nn.Layer.create_tensor",
        reason="New api in create_tensor, easier to use.",
    )
    def create_variable(self, name=None, persistable=None, dtype=None):
        """

        Create Tensor for this layer.

        Parameters:
            name(str, optional): name of the tensor. Please refer to :ref:`api_guide_Name` . Default: None

            persistable(bool, optional): whether to set this tensor as persistable. Default: False

            dtype(str, optional): data type of this parameter. If set str, it can be "bool", "float16", "float32", "float64","int8", "int16", "int32", "int64", "uint8" or "uint16". If set None, it will be "float32". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super().__init__()
                        self.linear = paddle.nn.Linear( 10, 10)

                        self.back_var = self.create_variable(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        if name is not None:
            var_name = ".".join([self._full_name, name])
        else:
            var_name = unique_name.generate(
                ".".join([self._full_name, "_generated_var"])
            )

        return self._helper.main_program.current_block().create_var(
            name=var_name,
            persistable=persistable,
            dtype=dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
        )

    # TODO: Add more parameter list when we need them
    def create_tensor(self, name=None, persistable=None, dtype=None):
        """

        Create Tensor for this layer.

        Parameters:
            name(str, optional): name of the tensor. Please refer to :ref:`api_guide_Name` . Default: None
            persistable(bool, optional): whether to set this tensor as persistable. Default: False
            dtype(str, optional): data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                If set None, it will be "float32". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super().__init__()
                        self.linear = paddle.nn.Linear( 10, 10)

                        self.back_var = self.create_tensor(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        if name is not None:
            var_name = ".".join([self._full_name, name])
        else:
            var_name = unique_name.generate(
                ".".join([self._full_name, "_generated_var"])
            )

        return self._helper.main_program.current_block().create_var(
            name=var_name,
            persistable=persistable,
            dtype=dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
        )

    def parameters(self, include_sublayers=True):
        """

        Returns a list of all Parameters from current layer and its sub-layers.

        Returns:
            list of Tensor, a list of Parameters.

        Examples:
            .. code-block:: python

                import paddle

                linear = paddle.nn.Linear(1,1)
                print(linear.parameters())  # print linear_0.w_0 and linear_0.b_0

        """
        ret = [
            param
            for _, param in self.named_parameters(
                include_sublayers=include_sublayers
            )
        ]
        return ret

    def children(self):
        """

        Returns an iterator over immediate children layers.

        Yields:
            Layer: a child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)

                layer_list = list(model.children())

                print(layer_list)   # [<paddle.nn.layer.common.Linear object at 0x7f7b8113f830>, <paddle.nn.layer.common.Linear object at 0x7f7b8113f950>]

        """
        for _, layer in self.named_children():
            yield layer

    def named_children(self):
        """Returns an iterator over immediate children layers, yielding both
        the name of the layer as well as the layer itself.

        Yields:
            (string, Layer): Tuple containing a name and child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)
                for prefix, layer in model.named_children():
                    print(prefix, layer)
                    # ('0', <paddle.nn.layer.common.Linear object at 0x7fb61ed85830>)
                    # ('1', <paddle.nn.layer.common.Linear object at 0x7fb61ed85950>)

        """
        memo = set()
        for name, layer in self._sub_layers.items():
            if layer is not None and layer not in memo:
                memo.add(layer)
                yield name, layer

    def sublayers(self, include_self=False):
        """

        Returns a list of sub layers.

        Parameters:
            include_self(bool, optional): Whether to return self as a sublayer. Default: False

        Returns:
            list of Layer, a list of sub layers.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                mylayer = MyLayer()
                print(mylayer.sublayers())  # [<paddle.nn.layer.common.Linear object at 0x7f44b58977d0>, <paddle.nn.layer.common.Dropout object at 0x7f44b58978f0>]

        """
        ret = [
            layer
            for _, layer in self.named_sublayers(include_self=include_self)
        ]
        return ret

    def named_parameters(self, prefix='', include_sublayers=True):
        """
        Returns an iterator over all parameters in the Layer, yielding tuple of name and parameter.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_sublayers(bool, optional): Whether include the parameters of sublayers.
                If True, also include the named parameters from sublayers. Default: True.

        Yields:
            (string, Parameter): Tuple of name and Parameter

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for name, param in model.named_parameters():
                    print(name, param)

        """
        params_set = set()
        named_sublayers = (
            self.named_sublayers(prefix=prefix, include_self=True)
            if include_sublayers
            else zip([prefix], [self])
        )
        for layer_prefix, sublayer in named_sublayers:
            params = sublayer._parameters.items()
            for key, param in params:
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                name = layer_prefix + ('.' if layer_prefix else '') + key
                yield name, param

    def named_sublayers(self, prefix='', include_self=False, layers_set=None):
        """
        Returns an iterator over all sublayers in the Layer, yielding tuple of name and sublayer.
        The duplicate sublayer will only be yielded once.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_self(bool, optional): Whether include the Layer itself. Default: False.
            layers_set(set, optional): The set to record duplicate sublayers. Default: None.

        Yields:
            (string, Layer): Tuple of name and Layer

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        if layers_set is None:
            layers_set = set()
        if include_self and self not in layers_set:
            layers_set.add(self)
            yield prefix, self
        for key, layer in self._sub_layers.items():
            if layer is None:
                continue
            layer_prefix = prefix + ('.' if prefix else '') + key
            for p, l in layer.named_sublayers(
                prefix=layer_prefix, include_self=True, layers_set=layers_set
            ):
                yield p, l

    def register_buffer(self, name, tensor, persistable=True):
        """
        Registers a tensor as a buffer of the layer.

        `buffer` is a non-trainable tensor and will not be updated by the optimizer,
        but is necessary for evaluation and inference. For example, the mean and variance in BatchNorm layers.
        The registered buffer is persistable by default, and will be saved into
        `state_dict` alongside parameters. If set persistable=False, it registers
        a non-persistable buffer, so that it will not be a part of `state_dict` .

        Buffers can be accessed as attributes using given names.

        Parameters:
            name (string): name of the buffer. The buffer can be accessed
                from this layer using the given name
            tensor (Tensor): the tensor to be registered as buffer.
            persistable (bool): whether the buffer is part of this layer's
                state_dict.

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                # get the buffer by attribute.
                print(linear.buf_name)

        """

        if '_buffers' not in self.__dict__:
            raise ValueError("super().__init__() should be called first")
        elif not isinstance(name, str):
            raise TypeError(
                "The name of buffer should be a string, but received {}.".format(
                    type(name).__name__
                )
            )
        elif '.' in name:
            raise KeyError(
                "The name of buffer can not contain `.`, "
                "because when you access the newly added buffer in the "
                "form of `self.**.**`, it will cause AttributeError."
            )
        elif name == '':
            raise KeyError("The name of buffer can not be empty.")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists.".format(name))
        elif tensor is not None and not (
            type(tensor) == core.VarBase or type(tensor) == core.eager.Tensor
        ):
            raise TypeError(
                "The registered buffer should be a Paddle.Tensor, but received {}.".format(
                    type(tensor).__name__
                )
            )
        else:
            self._buffers[name] = tensor
            if persistable:
                self._non_persistable_buffer_names_set.discard(name)
            else:
                self._non_persistable_buffer_names_set.add(name)

    def buffers(self, include_sublayers=True):
        """

        Returns a list of all buffers from current layer and its sub-layers.

        Parameters:
            include_sublayers(bool, optional): Whether include the buffers of sublayers. If True, also include the buffers from sublayers. Default: True

        Returns:
            list of Tensor, a list of buffers.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                print(linear.buffers())     # == print([linear.buf_name])

        """
        ret = [
            buffer
            for _, buffer in self.named_buffers(
                include_sublayers=include_sublayers
            )
        ]
        return ret

    def named_buffers(self, prefix='', include_sublayers=True):
        """
        Returns an iterator over all buffers in the Layer, yielding tuple of name and Tensor.

        Parameters:
            prefix(str, optional): Prefix to prepend to all buffer names. Default: ''.
            include_sublayers(bool, optional): Whether include the buffers of sublayers.
                If True, also include the named buffers from sublayers. Default: True.

        Yields:
            (string, Tensor): Tuple of name and tensor

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                buffer1 = paddle.to_tensor(np.array([0]).astype("float32"))
                # register a tensor as buffer by specific `persistable`
                fc1.register_buffer("buf_name_1", buffer1, persistable=True)

                fc2 = paddle.nn.Linear(3, 10)
                buffer2 = paddle.to_tensor(np.array([1]).astype("float32"))
                # register a buffer by assigning an attribute with Tensor.
                # The `persistable` can only be False by this way.
                fc2.buf_name_2 = buffer2

                model = paddle.nn.Sequential(fc1, fc2)

                # get all named buffers
                for name, buffer in model.named_buffers():
                    print(name, buffer)

        """
        buffers_set = set()
        named_sublayers = (
            self.named_sublayers(prefix=prefix, include_self=True)
            if include_sublayers
            else zip([prefix], [self])
        )
        for layer_prefix, sublayer in named_sublayers:
            buffers = sublayer._buffers.items()
            for key, buffer in buffers:
                if buffer is None or buffer in buffers_set:
                    continue
                buffers_set.add(buffer)
                name = layer_prefix + ('.' if layer_prefix else '') + key
                yield name, buffer

    def clear_gradients(self):
        """
        Clear the gradients of all parameters for this layer.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)
                linear = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01,
                                            parameters=linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                linear.clear_gradients()

        """
        for p in self.parameters():
            if p.trainable:
                p.clear_gradient()

    def _build_once(self, *args, **kwargs):
        pass

    def _dygraph_call_func(self, *inputs, **kwargs):
        for forward_pre_hook in self._forward_pre_hooks.values():
            hook_result = forward_pre_hook(self, inputs)
            if hook_result is not None:
                if not isinstance(hook_result, tuple):
975
                    hook_result = (hook_result,)
                inputs = hook_result

        if not self._built:
            with program_desc_tracing_guard(False):
                self._build_once(*inputs, **kwargs)

                # TODO(liuyuhui) Only xpu broadcast parameters here.
                # The other device is to call _sync_params_buffers in DataParallel
                # to realize the parameter synchronization among multiply cards.
                if (
                    parallel_helper._is_data_parallel_mode()
                    and paddle.is_compiled_with_xpu()
                ):
                    parallel_helper._broadcast_parameters(
                        self._parameters.values()
                    )

            self._built = True

        if in_profiler_mode():
            with profiler.RecordEvent(
                self.__class__.__name__, profiler.TracerEventType.Forward
            ):
                outputs = self.forward(*inputs, **kwargs)
        else:
            outputs = self.forward(*inputs, **kwargs)

        for forward_post_hook in self._forward_post_hooks.values():
            hook_result = forward_post_hook(self, inputs, outputs)
            if hook_result is not None:
                outputs = hook_result

        return outputs

    def __call__(self, *inputs, **kwargs):
        if (
            (not in_declarative_mode())
            and (not self._forward_pre_hooks)
            and (not self._forward_post_hooks)
            and (not self._built)
            and in_dygraph_mode()
            and (not in_profiler_mode())
        ):
            self._build_once(*inputs, **kwargs)
            return self.forward(*inputs, **kwargs)
        else:
            return self._dygraph_call_func(*inputs, **kwargs)

    def forward(self, *inputs, **kwargs):
        """
        Defines the computation performed at every call.
        Should be overridden by all subclasses.

        Parameters:
            *inputs(tuple): unpacked tuple arguments
            **kwargs(dict): unpacked dict arguments
        """
        raise NotImplementedError

    def backward(self, *inputs):
        raise ValueError("Layer shouldn't implement backward")

    def add_sublayer(self, name, sublayer):
        """

        Adds a sub Layer instance.

        Added sublayer can be accessed by self.name

        Parameters:
            name(str): name of this sublayer.
            sublayer(Layer): an instance of Layer.
        Returns:
            Layer, the sublayer passed in.

        Examples:
            .. code-block:: python

                import paddle

                class MySequential(paddle.nn.Layer):
                    def __init__(self, *layers):
                        super().__init__()
                        if len(layers) > 0 and isinstance(layers[0], tuple):
                            for name, layer in layers:
                                self.add_sublayer(name, layer)
                        else:
                            for idx, layer in enumerate(layers):
                                self.add_sublayer(str(idx), layer)

                    def forward(self, input):
                        for layer in self._sub_layers.values():
                            input = layer(input)
                        return input

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = MySequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        assert isinstance(sublayer, Layer) or sublayer is None

        self._sub_layers[name] = sublayer
        return sublayer

    def add_parameter(self, name, parameter):
        """Adds a Parameter instance.

1086
        Added parameter can be accessed by self.name
X
1088 1089 1090
        Parameters:
            name(str): name of this sublayer.
            parameter(Parameter): an instance of Parameter.
X
U
ustiniankw 已提交
1092
            Parameter, the parameter passed in.
1093 1094 1095 1096 1097 1098 1099
        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        if '_parameters' not in self.__dict__:
            raise RuntimeError("super().__init__() should be called firstly.")
        elif not isinstance(name, str):
            raise TypeError(
                "The name of parameter should be a string, but received {}.".format(
                    type(name).__name__
                )
            )
        elif '.' in name:
            raise KeyError(
                "The name of parameter can not contain `.`, "
                "because when you access the newly added parameter in the "
                "form of `self.**.**`, it will cause AttributeError."
            )
        elif name == '':
            raise KeyError("The name of parameter can not be empty.")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("The parameter '{}' already exists.".format(name))
        elif parameter is not None and not isinstance(
            parameter, framework.Parameter
        ):
            raise TypeError(
                "The parameter to be added should be a Parameter, but received {}.".format(
                    type(parameter).__name__
                )
            )
        else:
            if parameter is None:
                self._parameters[name] = None

            if len(self._loaddict_holder) > 0:
                assert (
                    parameter.name in self._loaddict_holder
                ), "Parameter not found, cannot find [ {} ] in state_dict".format(
                    parameter.name
                )

                parameter.set_value(self._loaddict_holder[parameter.name])

            self._parameters[name] = parameter
        return parameter

    def _set_op_attrs(self, attrs):
        """
        Add customized attributes when append_op is called. In the case of quantization, we want to save
        some attributes into the op_desc while exporting the inference model by @to_static.

        Arguments:
            attrs(dict): customized attributes that will be added into op_descs.

        NOTE: The interface is only exposed to developers.
        """

        def is_already_registered(is_pre_hook):
            layers_hooks = (
                self._forward_pre_hooks
                if is_pre_hook
                else self._forward_post_hooks
            )
            candidate_hook = (
                record_program_ops_pre_hook
                if is_pre_hook
                else set_op_customized_attrs_post_hook
            )

            already_registed = False
            if layers_hooks:
                last_key = next(reversed(layers_hooks))
                already_registed = layers_hooks[last_key] == candidate_hook

            return already_registed

        if not isinstance(attrs, dict):
            raise TypeError(
                "attrs should be type(dict), but received {}".format(
                    type(attrs).__name__
                )
            )

        # NOTE: Overwrite behavior for same key.
        self._customized_attrs.update(attrs)

        if not is_already_registered(is_pre_hook=True):
            pre_hook_helper = self.register_forward_pre_hook(
                record_program_ops_pre_hook
            )
            assert len(self._op_recorder.hooks) == 0
            self._op_recorder.hooks = [pre_hook_helper]

        # manually register post_hook to ensure it is inserted into the head.
        if not is_already_registered(is_pre_hook=False):
            post_hook_helper = self.register_forward_post_hook(
                set_op_customized_attrs_post_hook
            )
            if len(self._forward_post_hooks) > 1:
                self._forward_post_hooks.move_to_end(
                    post_hook_helper._hook_id, last=False
                )

            assert len(self._op_recorder.hooks) == 1

            # hooks that need to be removed once we finish executing them.
            self._op_recorder.hooks.append(post_hook_helper)
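
    # Editor's sketch (hypothetical attribute values; developer-only interface): a typical
    # use of _set_op_attrs is to stamp extra attributes onto the ops recorded for this layer
    # when exporting with @to_static, e.g.
    #   layer._set_op_attrs({"quant_bits": 8})
    # which registers record_program_ops_pre_hook / set_op_customized_attrs_post_hook so the
    # customized attributes end up in the generated op_descs.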

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in self._parameters:
                if in_declarative_mode():
                    return _convert_into_variable(self._parameters[name])
                return self._parameters[name]
        if '_sub_layers' in self.__dict__:
            _sub_layers = self.__dict__['_sub_layers']
            if name in self._sub_layers:
                return self._sub_layers[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                if in_declarative_mode():
                    return _convert_into_variable(_buffers[name])
                return _buffers[name]
        return object.__getattribute__(self, name)

    def __setattr__(self, name, value):
        def _remove_if_exist(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        if isinstance(getattr(type(self), name, None), property):
            object.__setattr__(self, name, value)
        params = self.__dict__.get('_parameters', None)
        if isinstance(value, framework.Parameter):
            if params is None:
                raise ValueError("super().__init__() should be called first")
            if len(self._loaddict_holder) > 0:
                assert (
                    value.name in self._loaddict_holder
                ), "Parameter not found, cannot find [ {} ] in state_dict".format(
                    value.name
                )

                value.set_value(self._loaddict_holder[value.name])

            _remove_if_exist(self.__dict__, self._buffers, self._sub_layers)
            params[name] = value
        elif params is not None and name in params:
            if value is not None:
                raise TypeError(
                    "assignment to parameter '{}' should be of type Parameter or None, but got '{}'".format(
                        name, type(value).__name__
                    )
                )
            params[name] = None
        else:
            layers = self.__dict__.get('_sub_layers', None)
            if isinstance(value, Layer):
                if layers is None:
                    raise ValueError(
                        "super().__init__() should be called first"
                    )

                _remove_if_exist(self.__dict__, self._parameters, self._buffers)
                layers[name] = value
            elif layers is not None and name in layers:
                if value is not None:
                    raise TypeError(
                        "assignment to sublayer '{}' should be of type Layer or None, but got '{}'".format(
                            name, type(value).__name__
                        )
                    )
                layers[name] = None
            else:
                _buffers = self.__dict__.get('_buffers', None)
                if isinstance(value, (core.VarBase, core.eager.Tensor)):
                    if _buffers is None:
                        raise ValueError(
                            "super().__init__() should be called first"
                        )
                    _remove_if_exist(
                        self.__dict__, self._parameters, self._sub_layers
                    )
                    # Set persistable=False by default. Only `register_buffer` can
                    # add a persistable buffer.
                    if name not in self._buffers:
                        self._non_persistable_buffer_names_set.add(name)
                    if not value.name:
                        value.name = unique_name.generate('_buffers_' + name)
                    _buffers[name] = value
                elif _buffers is not None and name in _buffers:
                    # Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
                    # decorated function, such as `self.buffer = new_tensor`. So we update its
                    # value via `assign`.
                    if type(value) == framework.Variable:
                        from paddle import assign

                        # Note(zhhsplendid): the condition below happens in the PaddleGan model,
                        # but should all non-Variable _buffers[name] be re-assigned? We
                        # should consider it in the future. For now this is written as
                        # conservative code.
                        if in_declarative_mode() and _buffers[name] is None:
                            raise RuntimeError(
                                'In Dy2stat, self.{0} is a buffer and self.{0} is '
                                'not allowed to be set to Variable when self.{0} is None.'.format(
                                    name
                                )
                            )
                        elif (
                            _buffers[name] is None
                            or type(getattr(self, name)) == core.VarBase
                        ):
                            _buffers[name] = assign(value)
                        else:
                            assign(value, getattr(self, name))
                    elif value is not None:
                        raise TypeError(
                            "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'".format(
                                name, type(value).__name__
                            )
                        )
                    else:
                        # Assigning None will remove the buffer, but if a new VarBase is re-assigned
                        # to it, it will be remarked as a buffer with the same `persistable` attribute.
                        _buffers[name] = None
                else:
                    object.__setattr__(self, name, value)

    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._sub_layers:
            del self._sub_layers[name]
        elif name in self._buffers:
            del self._buffers[name]
            self._non_persistable_buffer_names_set.discard(name)
        else:
            object.__delattr__(self, name)

    def __dir__(self):
        """
        Return a list of all parameters, buffers (non-parameter tensors), sublayers, methods and attributes of the Layer.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                class Mylayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.linear1 = paddle.nn.Linear(10, 10)
                        self.linear2 = paddle.nn.Linear(5, 5)
                        self.conv2d = paddle.nn.Conv2D(3, 2, 3)
                        self.embedding = paddle.nn.Embedding(128, 16)
                        self.h_0 = paddle.to_tensor(np.zeros([10, 10]).astype('float32'))

                mylayer = Mylayer()
                print(dir(mylayer))
                # only part of the output is shown, because the list has too much content
                # ['__call__', '__class__',  ... , 'conv2d', 'embedding', 'h_0', 'linear1', 'linear2', ... , 'sublayers', 'train']

        """
        method = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        sublayers = list(self._sub_layers.keys())
        buffers = list(self._buffers.keys())

        keys = method + attrs + parameters + sublayers + buffers

        return keys

    def extra_repr(self):
        """
        Extra representation of this layer; implement it in your own layer to
        customize the extra information shown by ``__repr__``.
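
        Examples:
            .. code-block:: python

                # A minimal sketch: override extra_repr in a custom Layer so that
                # printing the layer shows its configuration. The layer and names
                # below are illustrative, not part of the API.
                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self, hidden_size=16):
                        super().__init__()
                        self.linear = paddle.nn.Linear(hidden_size, hidden_size)
                        self._hidden_size = hidden_size

                    def extra_repr(self):
                        return 'hidden_size={}'.format(self._hidden_size)

                print(MyLayer())
                # prints something like:
                # MyLayer(hidden_size=16
                #   (linear): Linear(in_features=16, out_features=16, dtype=float32)
                # )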
        """
        return ''

    def __repr__(self):
        extra_lines = []
        extra_repr = self.extra_repr()
        extra_lines = extra_repr.split('\n')
        sublayer_lines = []
        for name, layer in self._sub_layers.items():
            sublayer_str = repr(layer)
            sublayer_str = _addindent(sublayer_str, 2)
            sublayer_lines.append('(' + name + '): ' + sublayer_str)

        final_str = self.__class__.__name__ + '('
        if extra_lines:
            if len(extra_lines) > 1:
                final_str += '\n  ' + '\n  '.join(extra_lines) + '\n'
            elif len(extra_lines) == 1:
                final_str += extra_lines[0]
        if sublayer_lines:
            final_str += '\n  ' + '\n  '.join(sublayer_lines) + '\n'

        final_str += ')'
        return final_str

    def register_state_dict_hook(self, hook):
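        """
        Register a hook for ``state_dict``. Every registered hook is called with the
        collected state dict after it has been built; if the hook returns a dict, that
        dict replaces the result.

        Returns a ``HookRemoveHelper`` object whose ``remove()`` method unregisters the hook.

        Examples:
            .. code-block:: python

                # A minimal sketch; the hook below is illustrative, not part of the API.
                import paddle

                layer = paddle.nn.Linear(2, 2)

                def add_prefix(state_dict):
                    # returning a new dict replaces the collected destination
                    return {'backbone.' + k: v for k, v in state_dict.items()}

                helper = layer.register_state_dict_hook(add_prefix)
                print(list(layer.state_dict().keys()))
                helper.remove()  # unregister the hook when it is no longer needed
        """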
        hook_remove_helper = HookRemoveHelper(self._state_dict_hooks)
        self._state_dict_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def _obtain_parameters_buffers(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
    ):
        """
        The difference from state_dict() is that state_dict_hook will not be called,
        but the original types of parameters and buffers will be maintained.
        """
        if destination is None:
            destination = collections.OrderedDict()
        for name, data in self._parameters.items():
            if data is not None:
                destination[structured_name_prefix + name] = data
        for name, buffer in self._buffers.items():
            if (
                buffer is not None
                and name not in self._non_persistable_buffer_names_set
            ):
                destination[structured_name_prefix + name] = buffer

        if include_sublayers:
            for layer_name, layer_item in self._sub_layers.items():
                if layer_item is not None:
                    destination_temp = destination.copy()
                    destination_temp.update(
                        layer_item._obtain_parameters_buffers(
                            destination_temp,
                            include_sublayers,
                            structured_name_prefix + layer_name + ".",
                        )
                    )
                    destination = destination_temp
        return destination

    def _state_dict_impl(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        include_non_persistable_buffer=False,
        use_hook=True,
    ):
        """
        Get all parameters and persistable buffers of the current layer and its sub-layers, and set them into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If True, also include the parameters and persistable buffers from sublayers. Default: True
            include_non_persistable_buffer(bool, optional): If True, also include the non-persistable buffers of the current layer and its sub-layers; this is used in pure fp16 and jit.save. Default: False
            use_hook(bool, optional) : If True, the operations contained in _state_dict_hooks will be appended to the destination. Default: True
        """

        if destination is None:
            destination = collections.OrderedDict()
        for name, data in self._parameters.items():
            if data is not None:
                destination[structured_name_prefix + name] = data
        for name, buffer in self._buffers.items():
            if not include_non_persistable_buffer:
                if (
                    buffer is not None
                    and name not in self._non_persistable_buffer_names_set
                ):
                    destination[structured_name_prefix + name] = buffer
            else:
                if buffer is not None:
                    destination[structured_name_prefix + name] = buffer

        if include_sublayers:
            for layer_name, layer_item in self._sub_layers.items():
                if layer_item is not None:
                    destination_temp = destination.copy()
                    destination_temp.update(
                        layer_item._state_dict_impl(
                            destination_temp,
                            include_sublayers,
                            structured_name_prefix + layer_name + ".",
                            include_non_persistable_buffer,
                            use_hook,
                        )
                    )
                    destination = destination_temp
        if use_hook:
            for state_dict_hook in self._state_dict_hooks.values():
                hook_result = state_dict_hook(destination)
                if hook_result is not None:
                    destination = hook_result

        return destination

    def to_static_state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        '''

        Get all parameters and buffers of the current layer and its sub-layers, and set them into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If True, also include the parameters and persistable buffers from sublayers. Default: True
            use_hook(bool, optional) : If True, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

        Returns:
            dict, a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.to_static_state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''
        return self._state_dict_impl(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
            include_non_persistable_buffer=True,
            use_hook=use_hook,
        )

    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        '''
        Get all parameters and persistable buffers of the current layer and its sub-layers, and set them into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If True, also include the parameters and persistable buffers from sublayers. Default: True
            use_hook(bool, optional) : If True, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

        Returns:
            dict: a dict containing all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''
        return self._state_dict_impl(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
            include_non_persistable_buffer=False,
            use_hook=use_hook,
        )

    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensors in the state_dict.

        Parameters:
            state_dict(dict) : Dict containing all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If True, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            missing_keys(list): A list of str containing the missing keys
            unexpected_keys(list): A list of str containing the unexpected keys

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")
                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)
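                # The returned lists can be used to verify the load; with a matching
                # state_dict both are expected to be empty (illustrative):
                missing_keys, unexpected_keys = emb.set_state_dict(para_state_dict)
                print(missing_keys, unexpected_keys)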

        '''
        missing_keys = []
        match_keys = set()
        unexpected_keys = []

        def _check_match(key, param):
            state = state_dict.get(key, None)
            if state is None:
                missing_keys.append(key)
                raise ValueError(
                    "{} is not found in the provided dict.".format(key)
                )
            if isinstance(state, dict) or isinstance(state, list):
                if len(state) != len(param):
                    missing_keys.append(key)
                    raise ValueError(
                        "{} receieves the length of {}, "
                        "but the expected shape is {}".format(
                            key, len(state), len(param)
                        )
                    )
                else:
                    match_keys.add(key)
                    return param, state
            else:
                state_shape = (
                    state.shape()
                    if inspect.ismethod(state.shape)
                    else state.shape
                )

                if list(state_shape) != list(param.shape):
                    missing_keys.append(key)
                    raise ValueError(
                        "{} receives a shape {}, but the expected shape is {}.".format(
                            key, list(state_shape), list(param.shape)
                        )
                    )
                match_keys.add(key)
                return param, state

        matched_param_state = []
        for key, param in self._state_dict_impl(use_hook=False).items():
            key_name = key if use_structured_name else param.name
            try:
                match_res = _check_match(key_name, param)
                matched_param_state.append(match_res)
            except ValueError as err:
                warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
        for key in state_dict.keys():
            if key not in match_keys:
                unexpected_keys.append(key)
        if in_dygraph_mode():
            for param, state in matched_param_state:
                param.set_value(state)
        else:

            def _set_var(var, ndarray):
                t = global_scope().find_var(var.name).get_tensor()
                p = t._place()
                if p.is_cpu_place():
                    place = core.CPUPlace()
                elif p.is_cuda_pinned_place():
                    place = core.CUDAPinnedPlace()
                elif p.is_xpu_place():
                    p = core.Place()
                    p.set_place(t._place())
                    place = core.XPUPlace(p.xpu_device_id())
                else:
                    p = core.Place()
                    p.set_place(t._place())
                    place = core.CUDAPlace(p.gpu_device_id())
                t.set(ndarray, place)

            try:
                executor = Executor(_get_device())._default_executor
                # restore parameter states
                core._create_loaded_parameter(
                    [param for param, state in matched_param_state],
                    global_scope(),
                    executor,
                )
                for param, state in matched_param_state:
                    _set_var(param, state)
            except ValueError as e:
                raise ValueError(
                    "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
                )

        return missing_keys, unexpected_keys

    def to(self, device=None, dtype=None, blocking=None):
        '''
        Cast the parameters and buffers of the Layer to the given device, dtype and blocking.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device on which to store the Layer.
            If None, the device is the same as the original Tensor. If device is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
            index of the GPU or XPU. Default: None.

            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same as the original Tensor. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, blocking is set to True. Default: None.

        Returns:
            self

        Examples:
            .. code-block:: python

                # required: skip
                import paddle

                linear=paddle.nn.Linear(2, 2)
                linear.weight
                #Parameter containing:
                #Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(dtype='float64')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(device='cpu')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])
                linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
                #       [[-0.04989364, -0.56889004],
                #        [ 0.33960250,  0.96878713]])

        '''
        return self._to_impl(
            device=device,
            dtype=dtype,
            blocking=blocking,
            include_sublayers=True,
            floating_only=False,
        )

    def _apply(self, func, device, dtype, blocking, include_sublayers=True):
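        # Apply ``func(tensor, device, dtype, blocking)`` to every parameter, every
        # parameter gradient and every buffer of this layer, recursing into sublayers
        # first when ``include_sublayers`` is True. Transformed buffers are written back
        # to ``self._buffers``; parameters and gradients rely on ``func`` updating them
        # in place (see ``_transform``, which shares the result with the original tensor).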
        if include_sublayers:
            for layer in self.children():
                layer._apply(func, device, dtype, blocking, include_sublayers)

        for key, param in self._parameters.items():
            if param is not None:
                with no_grad():
                    param_applied = func(param, device, dtype, blocking)

                if param.grad is not None:
                    with no_grad():
                        grad_applied = func(
                            param._grad_ivar(), device, dtype, blocking
                        )

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = func(buf, device, dtype, blocking)

        self._dtype = dtype

    def _transform(self, t, device, dtype, blocking):
        if device is None:
            device = t.place
        if dtype is None:
            dtype = t.dtype

        if type(dtype) is not VarDesc.VarType:
            dtype = convert_np_dtype_to_dtype_(dtype)

        # 1. gpu place need to determine whether the memory is sufficient for allocation:
        if t.place.is_gpu_place():
            # for gpu, minimum memory allocation unit is 256 bytes.
            size_dtype = core.size_of_dtype(dtype)
            # Note(zhangbo): the Paddle GPU minimum memory allocation unit is 256 bytes; waiting_alloc_memory computes the memory space occupied by 't'.
            # Coefficient 1.2 is used to avoid OOM that may occur in this critical state when the memory is just enough.
            waiting_alloc_memory = (
                ((np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2
            )
            gpu_memory_available = core.gpu_memory_available()
            if gpu_memory_available < waiting_alloc_memory:
                # Copy param / Tensor to cpu
                t_used = t._copy_to(
                    paddle.CPUPlace(), blocking
                )  # k-v type will error
                # Release mem of t
                t.value().get_tensor()._clear()
            else:
                t_used = t
        else:
            t_used = t

        # 2. cast param / Tensor to dtype
        if dtype is not None and dtype != t_used.dtype:
            with paddle.fluid.framework._dygraph_place_guard(
                place=t_used.place
            ):
                t_casted = t_used.cast(dtype=dtype)
        else:
            t_casted = t_used

        # 3. Copy casted cpu param / Tensor to device
        if device is not None and not t_casted.place._equals(device):
            new_t = t_casted._copy_to(device, blocking)
        else:
            new_t = t_casted

        # 4. share Tensor to origin param / Tensor
        dst_tensor = t.value().get_tensor()
        src_tensor = new_t.value().get_tensor()
        dst_tensor._share_data_with(src_tensor)

        return t

    def _to_impl(
        self,
        device=None,
        dtype=None,
        blocking=None,
        include_sublayers=True,
        floating_only=False,
    ):
        '''
        Cast the parameters and buffers of the Layer to the given device, dtype and blocking.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device on which to store the Layer.
            If None, the device is the same as the original Tensor. If device is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
            index of the GPU or XPU. Default: None.

            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same as the original Tensor. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, blocking is set to True. Default: None.

            include_sublayers(bool, optional): If True, cast the parameters and buffers of this layer and all its sublayers; otherwise only cast this layer's own parameters and buffers. Default: True.

            floating_only(bool, optional): If True, only cast the floating point parameters and buffers of the Layer to the given device, dtype and blocking. Default: False.

        Returns:
            self

        '''

        if device is None and dtype is None and blocking is None:
            return self

        if device is not None:
            if isinstance(device, str):
                device = paddle.device._convert_to_place(device)
            elif isinstance(
                device,
                (
                    core.CPUPlace,
                    core.CUDAPlace,
                    core.CUDAPinnedPlace,
                    core.XPUPlace,
                ),
            ):
                pass
            else:
                raise ValueError(
                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
                    + type(device).__name__
                )

        if blocking is None:
            blocking = True
        else:
            assert isinstance(
                blocking, bool
            ), "blocking value error, must be the True, False or None"

        def transform(t, device, dtype, blocking):
            if floating_only and (not paddle.is_floating_point(t)):
                return t
            return self._transform(t, device, dtype, blocking)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            self._apply(transform, device, dtype, blocking, include_sublayers)

        self._dtype = dtype
        return self

    def _startup_program(self):
        """
        Return a startup program containing the initialization operations of all parameters.

        NOTE(dev): This is a very low level API and only for inner developer.
        """
        startup_program = Program()
        for param in self.parameters():
            param._create_init_op(startup_program.global_block())

        return startup_program

    # [aliases] Compatible with old method names
    set_dict = set_state_dict
    load_dict = set_state_dict