layers.py 69.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
Xin Pan 已提交
15
import collections
16 17 18
import contextlib
import sys
import numpy as np
19
import re
20 21 22
import copy
import weakref
import warnings
23
from copy import deepcopy
24 25
import inspect

26
import paddle
C
chenjian 已提交
27
import paddle.profiler as profiler
28
from paddle.profiler.utils import in_profiler_mode
29

C
chengduo 已提交
30
from . import parallel_helper
X
Xin Pan 已提交
31
from .. import unique_name
32
from paddle.fluid import core
33
from .layer_object_helper import LayerObjectHelper
34 35 36 37 38 39 40 41 42 43 44
from .layer_hooks import (
    record_program_ops_pre_hook,
    set_op_customized_attrs_post_hook,
    LayerOpsRecoder,
)
from .base import (
    program_desc_tracing_guard,
    param_guard,
    in_declarative_mode,
    _convert_into_variable,
)
45
from paddle.fluid import framework
46
from ..param_attr import ParamAttr
47
from paddle.fluid.executor import Executor, global_scope
48 49 50 51 52
from paddle.fluid.framework import (
    _non_static_mode,
    convert_np_dtype_to_dtype_,
    in_dygraph_mode,
)
53
from paddle.fluid.framework import Program, program_guard
54
from paddle.fluid.framework import _current_expected_place as _get_device
55
from paddle.fluid.core import VarDesc
C
chentianyu03 已提交
56
from paddle.fluid.dygraph import no_grad
W
wanghuancoder 已提交
57
import paddle.utils.deprecated as deprecated
58

59
__all__ = ['Layer']

# Patterns used by _convert_camel_to_snake:
# - _first_cap_re puts "_" before a capitalized word that follows any
#   character ("MyLayer" -> "My_Layer");
# - _all_cap_re puts "_" between a lowercase letter and a following
#   uppercase letter ("getHTTP" -> "get_HTTP").
_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re = re.compile('([a-z])([A-Z])')


# Distributed (parallel) layer scopes mapped to the scope name of their
# single-card counterpart. "parallel_cross_entropy" is intentionally absent:
# although mp_layer has parallel_cross_entropy, it owns no parameters, so the
# mapping is not necessary.
_DIST_SCOPE_TO_SINGLE = {
    "row_parallel_linear": "linear",
    "column_parallel_linear": "linear",
    "vocab_parallel_embedding": "embedding",
}


def _scope_dist2single(dist_scope):
    """Return the single-card scope name corresponding to ``dist_scope``.

    Scopes without an entry in the mapping are returned unchanged.
    """
    return _DIST_SCOPE_TO_SINGLE.get(dist_scope, dist_scope)


def _convert_camel_to_snake(name):
    """Convert a CamelCase name to snake_case, e.g. ``MyLayer`` -> ``my_layer``.

    Relies on the module-level ``_first_cap_re`` and ``_all_cap_re`` patterns.
    """
    s1 = _first_cap_re.sub(r'\1_\2', name)
    return _all_cap_re.sub(r'\1_\2', s1).lower()


def _addindent(string, indent):
    """Indent every line of ``string`` except the first by ``indent`` spaces.

    A single-line ``string`` (no ``'\\n'``) is returned unchanged.
    """
    first, *rest = string.split('\n')
    if not rest:
        return string
    pad = indent * ' '
    return first + '\n' + '\n'.join(pad + line for line in rest)


class HookRemoveHelper:
    """A HookRemoveHelper that can be used to remove hook.

    Each helper is bound to one hook container (a mapping keyed by hook id)
    and receives an id that is unique among all helpers created in this
    process, taken from the class-level ``next_hook_id`` counter.
    """

    next_hook_id = 0

    def __init__(self, hooks):
        # Hold only a weak reference so a lingering helper does not keep the
        # owning hook container alive.
        self._hooks_ref = weakref.ref(hooks)
        self._hook_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id += 1

    def remove(self):
        """Remove the associated hook.

        Does nothing if the container has been garbage-collected or the hook
        was already removed, so calling it more than once is safe.
        """
        hooks = self._hooks_ref()
        if hooks is not None and self._hook_id in hooks:
            del hooks[self._hook_id]


class Layer:
    """
    Dynamic graph Layer based on object-oriented design (OOD). It holds the parameters of the layer, the structure of the forward graph, and so on.

    Parameters:
        name_scope (str, optional): prefix name used by the layer to name parameters.
            If prefix is "my_layer", parameter name in MyLayer
            can be "my_layer_0.w_n", where "w" is the parameter
            base name and "n" is a unique suffix auto-generated.
            If None, the prefix name will be the snake-cased class name. Default: None.
        dtype(str, optional): data type of this layer.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                Default: "float32"

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            class MyLayer(paddle.nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = paddle.nn.Linear(1, 1)
                    self._dropout = paddle.nn.Dropout(p=0.5)
                def forward(self, input):
                    temp = self._linear(input)
                    temp = self._dropout(temp)
                    return temp
            x = paddle.randn([10, 1], 'float32')
            mylayer = MyLayer()
            mylayer.eval()  # set mylayer._dropout to eval mode
            out = mylayer(x)
            mylayer.train()  # set mylayer._dropout to train mode
            out = mylayer(x)
    """

    def __init__(self, name_scope=None, dtype="float32"):
        """Initialize the layer's bookkeeping state.

        See the class docstring for the meaning of ``name_scope`` and ``dtype``.
        """
        self.training = True
        if name_scope is None:
            # Derive the scope from the class name (e.g. MyLayer -> my_layer)
            # and map distributed layer scopes to their single-card names.
            name_scope = _convert_camel_to_snake(self.__class__.__name__)
            name_scope = _scope_dist2single(name_scope)
        self._full_name = unique_name.generate(name_scope)
        self._helper = LayerObjectHelper(self._full_name)
        self._built = False
        self._dtype = dtype
        # Remember whether the layer was constructed in dynamic (non-static) mode.
        self._init_in_dynamic_mode = framework._non_static_mode()

        self._parameters = collections.OrderedDict()
        # Buffers: the variables (not parameters) created in the layer
        self._buffers = collections.OrderedDict()
        self._non_persistable_buffer_names_set = set()
        self._sub_layers = collections.OrderedDict()
        self._loaddict_holder = collections.OrderedDict()

        # Record generated op_descs in this layer
        self._op_recorder = LayerOpsRecoder(ops=[], hooks=[])
        self._customized_attrs = {}

        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

        self._casted_by_pure_fp16 = False

        self._state_dict_hooks = collections.OrderedDict()
        # Records original functions after @to_static to support rollback
        self._original_funcs = collections.OrderedDict()

    def train(self):
        """
        Sets this Layer and all its sublayers to training mode.
        This only affects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None

        Example::
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                mylayer.train()  # set mylayer._dropout to train mode
                out = mylayer(x)

        """
        # global setting in dygraph
        # NOTE(chenweihang): nn.Layer also can be used in static mode,
        # but _dygraph_tracer() can not be called in static mode
        if _non_static_mode():
            framework._dygraph_tracer().train_mode()
        # Layer-level setting
        self.training = True
        for layer in self.sublayers():
            layer.training = True

    def eval(self):
        """
        Sets this Layer and all its sublayers to evaluation mode.
        This only affects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None

        Example::
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                print(out)

        """
        # global setting in dygraph
        # NOTE(chenweihang): nn.Layer also can be used in static mode,
        # but _dygraph_tracer() can not be called in static mode
        if _non_static_mode():
            framework._dygraph_tracer().eval_mode()
        # Layer-level setting
        self.training = False
        for layer in self.sublayers():
            layer.training = False

    def apply(self, fn):
        """
        Applies ``fn`` recursively to every sublayer as well as self. Typical use
        includes initializing the parameters of a model.

        The recursion walks ``.children()`` depth-first and calls ``fn`` on a layer
        only after all of its sublayers have been processed. A sublayer shared by
        several parents is visited once per parent.

        Parameters:
            fn (function): a function to be applied to each sublayer

        Returns:
            Layer: self

        Example::
            .. code-block:: python

              import paddle
              import paddle.nn as nn

              net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

              def init_weights(layer):
                  if isinstance(layer, nn.Linear):
                      print('before init weight:', layer.weight.numpy())
                      new_weight = paddle.full(shape=layer.weight.shape, dtype=layer.weight.dtype, fill_value=0.9)
                      layer.weight.set_value(new_weight)
                      print('after init weight:', layer.weight.numpy())

              net.apply(init_weights)

              print(net.state_dict())
        """
        # Children first (post-order), then this layer itself.
        for layer in self.children():
            layer.apply(fn)

        fn(self)

        return self

    def full_name(self):
        """Full name for this layer: the unique name generated from its ``name_scope``
        when the layer was constructed (e.g. ``demo_linear_net_0`` for
        ``name_scope="demo_linear_net"``).

        Returns:
            str: full name of this layer.

        Example::
            .. code-block:: python

                import paddle

                class LinearNet(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__(name_scope = "demo_linear_net")
                        self._linear = paddle.nn.Linear(1, 1)

                    def forward(self, x):
                        return self._linear(x)

                linear_net = LinearNet()
                print(linear_net.full_name())   # demo_linear_net_0

        """
        return self._full_name

    def register_forward_post_hook(self, hook):
        """Register a forward post-hook for Layer. The hook will be called after `forward` function has been computed.

        It should have the following form, where the `input` and `output` arguments of the `hook` are the `input` and `output` of the `Layer` respectively.
        Users can use a forward post-hook to change the output of the Layer or to perform information statistics tasks on the Layer.

        hook(Layer, input, output) -> None or modified output

        Parameters:
            hook(function): a function registered as a forward post-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_post_hook change the output of the layer: output = output * 2
                def forward_post_hook(layer, input, output):
                    # user can use layer, input and output for information statistics tasks

                    # change the output
                    return output * 2

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

                value1 = np.arange(26).reshape(2, 13).astype("float32")
                in1 = paddle.to_tensor(value1)

                out0 = linear(in1)

                # remove the hook
                forward_post_hook_handle.remove()

                out1 = linear(in1)

                # hook change the linear's output to output * 2, so out0 is equal to out1 * 2.
                assert (out0.numpy() == (out1.numpy()) * 2).all()
        """
        # Store the hook under the helper's unique id so the helper can later
        # delete exactly this entry.
        hook_remove_helper = HookRemoveHelper(self._forward_post_hooks)
        self._forward_post_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def register_forward_pre_hook(self, hook):
        """Register a forward pre-hook for Layer. The hook will be called before `forward` function has been computed.

        It should have the following form, where the `input` argument of the `hook` is the `input` of the `Layer`.
        The hook can either return a tuple or a single modified value. We will wrap the value into a tuple if
        a single value is returned (unless that value is already a tuple).
        Users can use a forward pre-hook to change the input of the Layer or to perform information statistics tasks on the Layer.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_pre_hook change the input of the layer: input = input * 2
                def forward_pre_hook(layer, input):
                    # user can use layer and input for information statistics tasks

                    # change the input
                    input_return = (input[0] * 2)
                    return input_return

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                value0 = np.arange(26).reshape(2, 13).astype("float32")
                in0 = paddle.to_tensor(value0)
                out0 = linear(in0)

                # remove the hook
                forward_pre_hook_handle.remove()

                value1 = value0 * 2
                in1 = paddle.to_tensor(value1)
                out1 = linear(in1)

                # hook change the linear's input to input * 2, so out0 is equal to out1.
                assert (out0.numpy() == out1.numpy()).all()
        """
        # Store the hook under the helper's unique id so the helper can later
        # delete exactly this entry.
        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def create_parameter(
        self,
        shape,
        attr=None,
        dtype=None,
        is_bias=False,
        default_initializer=None,
    ):
        """Create parameters for this layer.

        Parameters:
            shape(list): Shape of the parameter.
            attr(ParamAttr, optional): Parameter attribute of weight. Please refer to :ref:`api_paddle_ParamAttr`. Default: None.
            dtype(str, optional): Data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16". Default: "float32".
            is_bias(bool, optional): if this is a bias parameter. Default: False.
            default_initializer(Initializer, optional): the default initializer for this parameter.
                If set None, default initializer will be set to paddle.nn.initializer.Xavier and paddle.nn.initializer.Constant
                for non-bias and bias parameter, respectively. Default: None.

        Returns:
            Tensor: the created parameter.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        # Work on a copy so the caller's ``attr`` object is never handed to the
        # helper (presumably the helper may modify it).
        temp_attr = copy.deepcopy(attr)
        # An empty-string attr is treated the same as passing no attr.
        if isinstance(temp_attr, str) and temp_attr == "":
            temp_attr = None
        return self._helper.create_parameter(
            temp_attr, shape, dtype, is_bias, default_initializer
        )

    @deprecated(
        since="2.0.0",
        update_to="paddle.nn.Layer.create_tensor",
        reason="New api in create_tensor, easier to use.",
    )
    def create_variable(self, name=None, persistable=None, dtype=None):
        """

        Create Tensor for this layer.

        Deprecated since 2.0.0: use ``create_tensor`` instead. This method
        simply forwards its arguments to ``create_tensor``.

        Parameters:
            name(str, optional): name of the tensor. Please refer to :ref:`api_guide_Name` . Default: None

            persistable(bool, optional): if set this tensor persistable. Default: False

            dtype(str, optional): data type of this parameter. If set str, it can be "bool", "float16", "float32", "float64","int8", "int16", "int32", "int64", "uint8" or "uint16". If set None, it will be "float32". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super().__init__()
                        self.linear = paddle.nn.Linear(in_features, out_features)

                        self.back_var = self.create_variable(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        # The body used to be a verbatim copy of ``create_tensor``; delegate so
        # the naming and variable-creation logic lives in a single place.
        return self.create_tensor(
            name=name, persistable=persistable, dtype=dtype
        )

    # TODO: Add more parameter list when we need them
    def create_tensor(self, name=None, persistable=None, dtype=None):
        """

        Create Tensor for this layer.

        Parameters:
            name(str, optional): name of the tensor. Please refer to :ref:`api_guide_Name` . Default: None
            persistable(bool, optional): if set this tensor persistable. Default: False
            dtype(str, optional): data type of this parameter.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                If set None, it will be "float32". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super().__init__()
                        self.linear = paddle.nn.Linear(in_features, out_features)

                        self.back_var = self.create_tensor(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        if name is not None:
            # A user-supplied name is scoped under this layer's full name.
            var_name = ".".join([self._full_name, name])
        else:
            # Otherwise generate a unique name under the same scope.
            var_name = unique_name.generate(
                ".".join([self._full_name, "_generated_var"])
            )

        return self._helper.main_program.current_block().create_var(
            name=var_name,
            persistable=persistable,
            dtype=dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
        )

    def parameters(self, include_sublayers=True):
        """Returns a list of all Parameters from current layer and its sub-layers.

        Parameters:
            include_sublayers(bool, optional): Whether to include the parameters of sublayers.
                If False, only the parameters owned by this layer itself are returned. Default: True.

        Returns:
            list of Tensor : a list of Parameters.

        Examples:
            .. code-block:: python

                import paddle

                linear = paddle.nn.Linear(1,1)
                print(linear.parameters())  # print linear_0.w_0 and linear_0.b_0

        """
        return [
            param
            for _, param in self.named_parameters(
                include_sublayers=include_sublayers
            )
        ]

    def children(self):
        """Returns an iterator over immediate children layers.

        Yields:
            Layer: a child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)

                layer_list = list(model.children())

                print(layer_list)   # [<paddle.nn.layer.common.Linear object at 0x7f7b8113f830>, <paddle.nn.layer.common.Linear object at 0x7f7b8113f950>]

        """
        for _, layer in self.named_children():
            yield layer

    def named_children(self):
        """Returns an iterator over immediate children layers, yielding both
        the name of the layer as well as the layer itself.

        Yields:
            (string, Layer): Tuple containing a name and child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)
                for prefix, layer in model.named_children():
                    print(prefix, layer)
                    # ('0', <paddle.nn.layer.common.Linear object at 0x7fb61ed85830>)
                    # ('1', <paddle.nn.layer.common.Linear object at 0x7fb61ed85950>)

        """
        # Skip unset (None) slots, and yield a child registered under several
        # names only once (under its first name).
        memo = set()
        for name, layer in self._sub_layers.items():
            if layer is not None and layer not in memo:
                memo.add(layer)
                yield name, layer

    def sublayers(self, include_self=False):
        """Returns a list of sub layers.

        Parameters:
            include_self(bool, optional): Whether return self as sublayers. Default: False

        Returns:
            list of Layer : a list of sub layers.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                mylayer = MyLayer()
                print(mylayer.sublayers())  # [<paddle.nn.layer.common.Linear object at 0x7f44b58977d0>, <paddle.nn.layer.common.Dropout object at 0x7f44b58978f0>]

        """
        return [
            layer
            for _, layer in self.named_sublayers(include_self=include_self)
        ]

    def named_parameters(self, prefix='', include_sublayers=True):
        """
        Returns an iterator over all parameters in the Layer, yielding tuple of name and parameter.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_sublayers(bool, optional): Whether include the parameters of sublayers.
                If True, also include the named parameters from sublayers. Default: True.

        Yields:
            (string, Parameter): Tuple of name and Parameter

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for name, param in model.named_parameters():
                    print(name, param)

        """
        params_set = set()
        # Either walk the whole layer tree (self included) or only this layer.
        named_sublayers = (
            self.named_sublayers(prefix=prefix, include_self=True)
            if include_sublayers
            else [(prefix, self)]
        )
        for layer_prefix, sublayer in named_sublayers:
            for key, param in sublayer._parameters.items():
                # Skip unset slots and parameters already yielded through
                # another layer, so each parameter appears once.
                if param is None or param in params_set:
                    continue
                params_set.add(param)
                name = layer_prefix + ('.' if layer_prefix else '') + key
                yield name, param

    def named_sublayers(self, prefix='', include_self=False, layers_set=None):
        """
        Returns an iterator over all sublayers in the Layer, yielding tuple of name and sublayer.
        The duplicate sublayer will only be yielded once.

        Parameters:
            prefix(str, optional): Prefix to prepend to all sublayer names. Default: ''.
            include_self(bool, optional): Whether include the Layer itself. Default: False.
            layers_set(set, optional): The set of layers already visited, used to skip
                duplicate sublayers. Default: None.

        Yields:
            (string, Layer): Tuple of name and Layer

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        if layers_set is None:
            layers_set = set()
        if include_self:
            if self in layers_set:
                # Already visited through another path: its subtree has been
                # (or is being) traversed, so descending again would only
                # repeat work -- or recurse forever on a reference cycle.
                return
            layers_set.add(self)
            yield prefix, self
        for key, layer in self._sub_layers.items():
            if layer is None:
                continue
            layer_prefix = prefix + ('.' if prefix else '') + key
            yield from layer.named_sublayers(
                prefix=layer_prefix, include_self=True, layers_set=layers_set
            )

    def register_buffer(self, name, tensor, persistable=True):
        """
        Register a tensor as a buffer of this layer.

        A buffer is a tensor the optimizer never updates but which is still
        needed for evaluation and inference (for example, the mean and
        variance in BatchNorm layers). Buffers are persistable by default and
        are then saved into ``state_dict`` together with the parameters; pass
        ``persistable=False`` to keep the buffer out of ``state_dict``.

        Once registered, the buffer is reachable as ``self.<name>``.

        Parameters:
            name (str): name of the buffer. Must be a non-empty string that
                contains no ``.`` and is not already used by another
                (non-buffer) attribute.
            tensor (Tensor|None): the tensor to be registered as buffer.
                ``None`` is accepted and stored as-is.
            persistable (bool, optional): whether the buffer is part of this
                layer's ``state_dict``. Default: True.

        Returns:
            None

        Raises:
            ValueError: if ``super().__init__()`` has not been called yet.
            TypeError: if ``name`` is not a string, or ``tensor`` is neither
                ``None`` nor a Paddle Tensor.
            KeyError: if ``name`` contains ``.``, is empty, or clashes with an
                existing non-buffer attribute.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                # get the buffer by attribute.
                print(linear.buf_name)

        """
        if '_buffers' not in self.__dict__:
            raise ValueError("super().__init__() should be called first")
        if not isinstance(name, str):
            raise TypeError(
                "The name of buffer should be a string, but received {}.".format(
                    type(name).__name__
                )
            )
        if '.' in name:
            raise KeyError(
                "The name of buffer can not contain `.`, "
                "because when you access the newly added buffer in the "
                "form of `self.**.**`, it will cause AttributeError."
            )
        if name == '':
            raise KeyError("The name of buffer can not be empty.")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists.".format(name))
        if tensor is not None:
            # Exact type match on purpose: subclasses (e.g. parameters) are
            # not accepted as buffers.
            tensor_type = type(tensor)
            if tensor_type != core.VarBase and tensor_type != core.eager.Tensor:
                raise TypeError(
                    "The registered buffer should be a Paddle.Tensor, but received {}.".format(
                        type(tensor).__name__
                    )
                )

        self._buffers[name] = tensor
        # Track persistability separately: state_dict skips names in this set.
        non_persistable_names = self._non_persistable_buffer_names_set
        if persistable:
            non_persistable_names.discard(name)
        else:
            non_persistable_names.add(name)

    def buffers(self, include_sublayers=True):
        """
        Collect the buffers of this layer (and optionally its sub-layers) into a list.

        Parameters:
            include_sublayers(bool, optional): If True, buffers owned by
                sub-layers are collected too. Default: True

        Returns:
            list of Tensor : the buffers, in the order ``named_buffers`` yields them.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                print(linear.buffers())     # == print([linear.buf_name])

        """
        named_pairs = self.named_buffers(include_sublayers=include_sublayers)
        return [tensor for _, tensor in named_pairs]

    def named_buffers(self, prefix='', include_sublayers=True):
        """
        Iterate over the buffers of this layer as ``(name, tensor)`` pairs.

        ``None`` buffers are skipped, and a buffer object shared by several
        layers is yielded only once (under the first name it is found at).

        Parameters:
            prefix(str, optional): Prefix to prepend to all buffer names. Default: ''.
            include_sublayers(bool, optional): Whether include the buffers of sublayers.
                If True, also include the named buffers from sublayers. Default: True.

        Yields:
            (string, Tensor): Tuple of dotted name and tensor.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                buffer1 = paddle.to_tensor(np.array([0]).astype("float32"))
                # register a tensor as buffer by specific `persistable`
                fc1.register_buffer("buf_name_1", buffer1, persistable=True)

                fc2 = paddle.nn.Linear(3, 10)
                buffer2 = paddle.to_tensor(np.array([1]).astype("float32"))
                # register a buffer by assigning an attribute with Tensor.
                # The `persistable` can only be False by this way.
                fc2.buf_name_2 = buffer2

                model = paddle.nn.Sequential(fc1, fc2)

                # get all named buffers
                for name, buffer in model.named_buffers():
                    print(name, buffer)

        """
        already_yielded = set()
        if include_sublayers:
            owners = self.named_sublayers(prefix=prefix, include_self=True)
        else:
            owners = [(prefix, self)]

        for owner_prefix, owner in owners:
            separator = '.' if owner_prefix else ''
            for key, tensor in owner._buffers.items():
                if tensor is None or tensor in already_yielded:
                    continue
                already_yielded.add(tensor)
                yield owner_prefix + separator + key, tensor

X
Xin Pan 已提交
923
    def clear_gradients(self):
        """
        Clear the gradient of every trainable parameter of this layer.

        Parameters whose ``trainable`` flag is False are left untouched.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)
                linear = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01,
                                             parameters=linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                linear.clear_gradients()

        """
        trainable_params = (param for param in self.parameters() if param.trainable)
        for param in trainable_params:
            param.clear_gradient()
X
Xin Pan 已提交
950

951
    def _build_once(self, *args, **kwargs):
        """
        Hook called with the call's inputs while the layer is not yet marked as built.

        The base implementation does nothing; subclasses may override it to
        perform input-dependent setup.
        """
        return None

954 955 956 957 958
    def _dygraph_call_func(self, *inputs, **kwargs):
        """
        Run the full call pipeline: forward pre-hooks, one-time build, forward, post-hooks.

        Pre-hooks are called as ``hook(self, inputs)``; a non-None result
        replaces ``inputs`` (non-tuple results are wrapped in a 1-tuple).
        Post-hooks are called as ``hook(self, inputs, outputs)``; a non-None
        result replaces ``outputs``. ``forward`` is wrapped in a profiler
        ``RecordEvent`` when profiler mode is on.
        """
        for pre_hook in self._forward_pre_hooks.values():
            replaced_inputs = pre_hook(self, inputs)
            if replaced_inputs is None:
                continue
            inputs = (
                replaced_inputs
                if isinstance(replaced_inputs, tuple)
                else (replaced_inputs,)
            )

        if not self._built:
            with program_desc_tracing_guard(False):
                self._build_once(*inputs, **kwargs)

                # TODO(liuyuhui) Only xpu broadcast parameters here.
                # The other device is to call _sync_params_buffers in DataParallel
                # to realize the parameter synchronization among multiple cards.
                needs_xpu_broadcast = (
                    parallel_helper._is_data_parallel_mode()
                    and paddle.is_compiled_with_xpu()
                )
                if needs_xpu_broadcast:
                    parallel_helper._broadcast_parameters(
                        self._parameters.values()
                    )

            self._built = True

        if not in_profiler_mode():
            outputs = self.forward(*inputs, **kwargs)
        else:
            with profiler.RecordEvent(
                self.__class__.__name__, profiler.TracerEventType.Forward
            ):
                outputs = self.forward(*inputs, **kwargs)

        for post_hook in self._forward_post_hooks.values():
            replaced_outputs = post_hook(self, inputs, outputs)
            if replaced_outputs is not None:
                outputs = replaced_outputs

        return outputs

994
    def __call__(self, *inputs, **kwargs):
        # Fast path: plain dygraph execution with no forward hooks, a layer
        # not yet marked as built, and neither to_static nor the profiler
        # active. NOTE: this path never sets ``self._built``.
        needs_full_pipeline = (
            in_declarative_mode()
            or self._forward_pre_hooks
            or self._forward_post_hooks
            or self._built
            or not in_dygraph_mode()
            or in_profiler_mode()
        )
        if needs_full_pipeline:
            return self._dygraph_call_func(*inputs, **kwargs)
        self._build_once(*inputs, **kwargs)
        return self.forward(*inputs, **kwargs)
M
minqiyang 已提交
1007

1008
    def forward(self, *inputs, **kwargs):
        """
        Define the computation performed at every call of the layer.

        Every subclass must override this; the base implementation only raises.

        Parameters:
            *inputs(tuple): unpacked tuple arguments
            **kwargs(dict): unpacked dict arguments

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()
X
Xin Pan 已提交
1018 1019 1020 1021

    def backward(self, *inputs):
        """Always raise: layers do not define a custom backward pass."""
        message = "Layer shouldn't implement backward"
        raise ValueError(message)

X
Xin Pan 已提交
1022 1023 1024
    def add_sublayer(self, name, sublayer):
        """Register a sub Layer instance under ``name``.

        The registered sublayer can then be accessed as ``self.<name>``.

        Parameters:
            name(str): name of this sublayer.
            sublayer(Layer|None): an instance of Layer, or None.

        Returns:
            Layer: the sublayer passed in.

        Examples:
            .. code-block:: python

                import paddle

                class MySequential(paddle.nn.Layer):
                    def __init__(self, *layers):
                        super().__init__()
                        if len(layers) > 0 and isinstance(layers[0], tuple):
                            for name, layer in layers:
                                self.add_sublayer(name, layer)
                        else:
                            for idx, layer in enumerate(layers):
                                self.add_sublayer(str(idx), layer)

                    def forward(self, input):
                        for layer in self._sub_layers.values():
                            input = layer(input)
                        return input

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = MySequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        assert isinstance(sublayer, Layer) or sublayer is None

        registry = self._sub_layers
        registry[name] = sublayer
        return sublayer

    def add_parameter(self, name, parameter):
        """Adds a Parameter instance.

        The added parameter can be accessed as ``self.<name>``.

        Parameters:
            name(str): name of this parameter. Must be a non-empty string
                without ``.`` that is not already used by another attribute.
            parameter(Parameter|None): an instance of Parameter, or None to
                register ``name`` with an empty (None) entry.

        Returns:
            Parameter|None: the parameter passed in.

        Raises:
            RuntimeError: if ``super().__init__()`` has not been called yet.
            TypeError: if ``name`` is not a string, or ``parameter`` is neither
                None nor a Parameter.
            KeyError: if ``name`` contains ``.``, is empty, or clashes with an
                existing non-parameter attribute.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        if '_parameters' not in self.__dict__:
            raise RuntimeError("super().__init__() should be called firstly.")
        elif not isinstance(name, str):
            raise TypeError(
                "The name of parameter should be a string, but received {}.".format(
                    type(name).__name__
                )
            )
        elif '.' in name:
            raise KeyError(
                "The name of parameter can not contain `.`, "
                "because when you access the newly added parameter in the "
                "form of `self.**.**`, it will cause AttributeError."
            )
        elif name == '':
            raise KeyError("The name of parameter can not be empty.")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("The parameter '{}' already exists.".format(name))
        elif parameter is not None and not isinstance(
            parameter, framework.Parameter
        ):
            raise TypeError(
                "The parameter to be added should be a Parameter, but received {}.".format(
                    type(parameter).__name__
                )
            )
        else:
            # A None placeholder has no ``.name`` to look up in a pending
            # state_dict, so only real Parameters are matched against
            # _loaddict_holder (previously None crashed with AttributeError).
            if parameter is not None and len(self._loaddict_holder) > 0:
                assert (
                    parameter.name in self._loaddict_holder
                ), "Parameter not found, Can't not find [ {} ] in state_dict".format(
                    parameter.name
                )

                parameter.set_value(self._loaddict_holder[parameter.name])

            self._parameters[name] = parameter
        return parameter

1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147
    def _set_op_attrs(self, attrs):
        """
        Add customized attributes to the ops appended while this layer runs.

        In case of quantization, for example, some attributes should be saved
        into op_desc while exporting an inference model by @to_static.

        Parameters:
            attrs(dict): customized attributes that will be added into
                op_descs. Values for existing keys are overwritten.

        Raises:
            TypeError: if ``attrs`` is not a dict.

        NOTE: The interface is only exposed to developers.
        """

        def _hook_is_last(is_pre_hook):
            # True when the most recently registered hook of the requested
            # kind is already the op-attr hook; used to avoid registering it
            # again on repeated calls.
            if is_pre_hook:
                hooks = self._forward_pre_hooks
                target_hook = record_program_ops_pre_hook
            else:
                hooks = self._forward_post_hooks
                target_hook = set_op_customized_attrs_post_hook
            if not hooks:
                return False
            return hooks[next(reversed(hooks))] == target_hook

        if not isinstance(attrs, dict):
            raise TypeError(
                "attrs should be type(dict), but received {}".format(
                    type(attrs).__name__
                )
            )

        # NOTE: Overwrite behavior for same key.
        self._customized_attrs.update(attrs)

        if not _hook_is_last(is_pre_hook=True):
            pre_helper = self.register_forward_pre_hook(
                record_program_ops_pre_hook
            )
            assert len(self._op_recorder.hooks) == 0
            self._op_recorder.hooks = [pre_helper]

        # manually register post_hook to ensure it is inserted into the head.
        if not _hook_is_last(is_pre_hook=False):
            post_helper = self.register_forward_post_hook(
                set_op_customized_attrs_post_hook
            )
            if len(self._forward_post_hooks) > 1:
                self._forward_post_hooks.move_to_end(
                    post_helper._hook_id, last=False
                )

            assert len(self._op_recorder.hooks) == 1

            # hooks that need to be removed once we finish executing them.
            self._op_recorder.hooks.append(post_helper)

1198 1199 1200 1201 1202 1203
    def __getstate__(self):
        # Pickle through the instance attribute dict itself (not a copy).
        state = self.__dict__
        return state

    def __setstate__(self, state):
        # Restore pickled state by merging it into the attribute dict.
        vars(self).update(state)

X
Xin Pan 已提交
1204
    def __getattr__(self, name):
        """
        Resolve ``name`` from the parameter, sub-layer and buffer registries.

        Only invoked when normal attribute lookup fails. Registries are
        searched in the order parameters, sub-layers, buffers. In declarative
        (to_static) mode, parameters and buffers are returned through
        ``_convert_into_variable``. If nothing matches,
        ``object.__getattribute__`` is called to produce the standard error.
        """
        state = self.__dict__

        if '_parameters' in state:
            parameters = state['_parameters']
            if name in parameters:
                found = parameters[name]
                if in_declarative_mode():
                    return _convert_into_variable(found)
                return found

        if '_sub_layers' in state:
            sub_layers = state['_sub_layers']
            if name in sub_layers:
                return sub_layers[name]

        if '_buffers' in state:
            buffers = state['_buffers']
            if name in buffers:
                found = buffers[name]
                if in_declarative_mode():
                    return _convert_into_variable(found)
                return found

        return object.__getattribute__(self, name)
X
Xin Pan 已提交
1222 1223

    def __setattr__(self, name, value):
        """
        Route attribute assignment into the parameter / sub-layer / buffer registries.

        - A name backed by a class-level ``property`` is assigned through the
          property only, and nothing is registered.
        - A ``framework.Parameter`` value is registered as a parameter.
        - A ``Layer`` value is registered as a sub-layer.
        - A Tensor value is registered as a buffer. A name not yet in
          ``_buffers`` is marked non-persistable; only ``register_buffer``
          adds persistable buffers.
        - Assigning ``None`` to an already-registered parameter, sub-layer or
          buffer keeps the name registered with a ``None`` value.
        - A static ``Variable`` assigned to an existing buffer (Dy2stat) is
          copied in via ``paddle.assign``.
        - Anything else becomes a plain instance attribute.

        Raises:
            ValueError: if ``super().__init__()`` has not been called before
                assigning a Parameter, Layer or Tensor.
            TypeError: if a non-None value of the wrong type is assigned to an
                already-registered parameter, sub-layer or buffer name.
            RuntimeError: in Dy2stat, when a Variable is assigned to a buffer
                whose current value is None.
        """

        def _remove_if_exist(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        if isinstance(getattr(type(self), name, None), property):
            # Let the property setter own the assignment. Without this early
            # return, a plain value would invoke the setter a second time (via
            # the final ``object.__setattr__``), and a Parameter / Layer /
            # Tensor would additionally be registered behind the property.
            object.__setattr__(self, name, value)
            return

        params = self.__dict__.get('_parameters', None)

        if isinstance(value, framework.Parameter):
            if params is None:
                raise ValueError("super().__init__() should be called first")
            if len(self._loaddict_holder) > 0:
                assert (
                    value.name in self._loaddict_holder
                ), "Parameter not found, Can't not find [ {} ] in state_dict".format(
                    value.name
                )

                value.set_value(self._loaddict_holder[value.name])

            _remove_if_exist(self.__dict__, self._buffers, self._sub_layers)
            params[name] = value
        elif params is not None and name in params:
            if value is not None:
                raise TypeError(
                    "assignment to parameter '{}' should be of type Parameter or None, but got '{}'".format(
                        name, type(value).__name__
                    )
                )
            params[name] = None
        else:
            layers = self.__dict__.get('_sub_layers', None)
            if isinstance(value, Layer):
                if layers is None:
                    raise ValueError(
                        "super().__init__() should be called first"
                    )

                _remove_if_exist(self.__dict__, self._parameters, self._buffers)
                layers[name] = value
            elif layers is not None and name in layers:
                if value is not None:
                    raise TypeError(
                        "assignment to sublayer '{}' should be of type Layer or None, but got '{}'".format(
                            name, type(value).__name__
                        )
                    )
                layers[name] = None
            else:
                _buffers = self.__dict__.get('_buffers', None)
                if isinstance(value, (core.VarBase, core.eager.Tensor)):
                    if _buffers is None:
                        raise ValueError(
                            "super().__init__() should be called first"
                        )
                    _remove_if_exist(
                        self.__dict__, self._parameters, self._sub_layers
                    )
                    # Set persistable=False by default. Only `register_buffer` can
                    # add a persistable buffer.
                    if name not in self._buffers:
                        self._non_persistable_buffer_names_set.add(name)
                    if not value.name:
                        value.name = unique_name.generate('_buffers_' + name)
                    _buffers[name] = value
                elif _buffers is not None and name in _buffers:
                    # Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
                    # decorated function, such as `self.buffer = new_tensor`. So we update its
                    # value via `assign`.
                    if type(value) == framework.Variable:
                        from paddle import assign

                        # Note(zhhsplendid): the condition below happens in PaddleGan model,
                        # but should all non-Variable _buffers[name] be re-assign? We
                        # should consider it in the future. I current wrote this as
                        # conservative code.
                        if in_declarative_mode() and _buffers[name] is None:
                            raise RuntimeError(
                                'In Dy2stat, self.{0} is a buffer and self.{0} is '
                                'not allowed to be set to Variable when self.{0} is None.'.format(
                                    name
                                )
                            )
                        elif (
                            _buffers[name] is None
                            or type(getattr(self, name)) == core.VarBase
                        ):
                            _buffers[name] = assign(value)
                        else:
                            assign(value, getattr(self, name))
                    elif value is not None:
                        raise TypeError(
                            "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'".format(
                                name, type(value).__name__
                            )
                        )
                    else:
                        # Assigning None keeps the name registered in _buffers with
                        # a None value. Because the name stays in _buffers, a tensor
                        # re-assigned later keeps the existing `persistable` setting.
                        _buffers[name] = None
                else:
                    object.__setattr__(self, name, value)
X
Xin Pan 已提交
1326 1327 1328 1329 1330 1331

    def __delattr__(self, name):
        """
        Delete ``name`` from the first registry holding it, else as a plain attribute.

        Registries are checked in the order parameters, sub-layers, buffers.
        Deleting a buffer also drops it from the non-persistable name set.
        """
        if name in self._parameters:
            del self._parameters[name]
            return
        if name in self._sub_layers:
            del self._sub_layers[name]
            return
        if name in self._buffers:
            del self._buffers[name]
            self._non_persistable_buffer_names_set.discard(name)
            return
        object.__delattr__(self, name)

1338 1339
    def __dir__(self):
        """
        Return a list of names: class attributes and methods, instance attributes,
        parameters, sublayers and buffers (non-parameter tensors) of the Layer.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                class Mylayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.linear1 = paddle.nn.Linear(10, 10)
                        self.linear2 = paddle.nn.Linear(5, 5)
                        self.conv2d = paddle.nn.Conv2D(3, 2, 3)
                        self.embedding = paddle.nn.Embedding(128, 16)
                        self.h_0 = paddle.to_tensor(np.zeros([10, 10]).astype('float32'))

                mylayer = Mylayer()
                print(dir(mylayer))
                # only parts are shown, because of list have too much content
                # ['__call__', '__class__',  ... , 'conv2d', 'embedding', 'h_0', 'linear1', 'linear2', ... , 'sublayers', 'train']

        """
        return (
            dir(self.__class__)
            + list(self.__dict__)
            + list(self._parameters)
            + list(self._sub_layers)
            + list(self._buffers)
        )

1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400
    def extra_repr(self):
        """
        Return extra text to embed in this layer's ``repr``.

        Override this in your own layer to add a custom description; the base
        implementation contributes nothing.
        """
        return ""

    def __repr__(self):
        """
        Build ``ClassName(<extra_repr>)``, with one indented ``(name): repr`` line per sub-layer.

        A single-line ``extra_repr`` is inlined after the opening parenthesis;
        a multi-line one is placed on its own indented lines.
        """
        extra_lines = self.extra_repr().split('\n')
        child_lines = [
            '(' + child_name + '): ' + _addindent(repr(child), 2)
            for child_name, child in self._sub_layers.items()
        ]

        pieces = [self.__class__.__name__, '(']
        # str.split always returns at least one element, so extra_lines is
        # never empty here.
        if len(extra_lines) > 1:
            pieces.append('\n  ' + '\n  '.join(extra_lines) + '\n')
        else:
            pieces.append(extra_lines[0])
        if child_lines:
            pieces.append('\n  ' + '\n  '.join(child_lines) + '\n')
        pieces.append(')')
        return ''.join(pieces)

1401 1402 1403 1404 1405
    def register_state_dict_hook(self, hook):
        """
        Register ``hook`` in ``self._state_dict_hooks``.

        When ``_state_dict_impl`` runs with ``use_hook=True``, each hook is
        called as ``hook(destination)``. A non-None return value replaces the
        destination dict.

        Returns:
            HookRemoveHelper: the helper created for ``self._state_dict_hooks``;
            its ``_hook_id`` is the key under which ``hook`` is stored.
        """
        remove_helper = HookRemoveHelper(self._state_dict_hooks)
        hook_key = remove_helper._hook_id
        self._state_dict_hooks[hook_key] = hook
        return remove_helper

1406 1407 1408 1409 1410 1411
    def _obtain_parameters_buffers(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
    ):
        """
        Collect parameters and persistable buffers into a name -> tensor dict.

        Differs from state_dict() in that state_dict hooks are NOT called, so
        the original types of parameters and buffers are kept.

        Parameters:
            destination(dict, optional): dict to fill; a new OrderedDict when None.
            include_sublayers(bool, optional): recurse into non-None sub-layers.
            structured_name_prefix(str, optional): prefix for every key.

        Returns:
            dict: this layer's entries plus, when recursing, its sub-layers'
            entries under ``<prefix><sublayer_name>.``.
        """
        if destination is None:
            destination = collections.OrderedDict()

        for key, param in self._parameters.items():
            if param is None:
                continue
            destination[structured_name_prefix + key] = param

        for key, buf in self._buffers.items():
            if buf is None or key in self._non_persistable_buffer_names_set:
                continue
            destination[structured_name_prefix + key] = buf

        if not include_sublayers:
            return destination

        for child_name, child in self._sub_layers.items():
            if child is None:
                continue
            # Each child fills a copy, which then becomes the running result.
            merged = destination.copy()
            merged.update(
                child._obtain_parameters_buffers(
                    merged,
                    include_sublayers,
                    structured_name_prefix + child_name + ".",
                )
            )
            destination = merged
        return destination

1442 1443 1444 1445 1446 1447 1448 1449
    def _state_dict_impl(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        include_non_persistable_buffer=False,
        use_hook=True,
    ):
        """
        Get all parameters and persistable buffers of current layer and its sub-layers, as a dict.

        Parameters:
            destination(dict, optional) : If provided, entries are written into this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the entries of sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key. Default: ""
            include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False
            use_hook(bool, optional) : If true, each hook in _state_dict_hooks is applied to the result; a non-None hook result replaces it. Default: True

        Returns:
            dict: the collected name -> tensor mapping.
        """
        if destination is None:
            destination = collections.OrderedDict()

        for key, param in self._parameters.items():
            if param is not None:
                destination[structured_name_prefix + key] = param

        for key, buf in self._buffers.items():
            if buf is None:
                continue
            keep = (
                include_non_persistable_buffer
                or key not in self._non_persistable_buffer_names_set
            )
            if keep:
                destination[structured_name_prefix + key] = buf

        if include_sublayers:
            for child_name, child in self._sub_layers.items():
                if child is None:
                    continue
                # Each child fills a copy; the child's (possibly hook-replaced)
                # result is merged back before it becomes the running result.
                merged = destination.copy()
                merged.update(
                    child._state_dict_impl(
                        merged,
                        include_sublayers,
                        structured_name_prefix + child_name + ".",
                        include_non_persistable_buffer,
                        use_hook,
                    )
                )
                destination = merged

        if use_hook:
            for hook in self._state_dict_hooks.values():
                hooked = hook(destination)
                if hooked is not None:
                    destination = hooked

        return destination

1498 1499 1500 1501 1502 1503 1504
    def to_static_state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        '''
        Get all parameters and buffers (persistable AND non-persistable) of current layer and its sub-layers, as a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and buffers will be set to this dict . Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key. Default: ""
            use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

        Returns:
            dict: a dict contains all the parameters and buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.to_static_state_dict()
                paddle.save( state_dict, "paddle_dy.pdparams")

        '''
        impl_kwargs = dict(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
            include_non_persistable_buffer=True,
            use_hook=use_hook,
        )
        return self._state_dict_impl(**impl_kwargs)

    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Buffers registered as non-persistable are excluded (use ``to_static_state_dict`` to include them).

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to the structured names used as keys of the returned dict. Default: ""
            use_hook(bool, optional) : If true, every hook in _state_dict_hooks is called with the result dict, and a non-None return value replaces it. Default: True

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")

        '''
        # include_non_persistable_buffer=False keeps non-persistable buffers out of the result.
        return self._state_dict_impl(
            destination=destination,
            include_sublayers=include_sublayers,
            structured_name_prefix=structured_name_prefix,
            include_non_persistable_buffer=False,
            use_hook=use_hook,
        )
1571

1572
    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Entries that are missing from ``state_dict``, or whose shape (or length, for dict/list
        values) does not match the target, are skipped with a warning instead of failing the load.

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Raises:
            ValueError: in static-graph mode, if restoring the loaded values fails
                (e.g. calling 'set_state_dict' dynamically inside 'forward' under dy2static).

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")
                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        def _check_match(key, param):
            # Return (param, state) if state_dict[key] is compatible with param,
            # otherwise raise ValueError describing the mismatch.
            state = state_dict.get(key, None)
            if state is None:
                raise ValueError(
                    "{} is not found in the provided dict.".format(key)
                )
            if isinstance(state, (dict, list)):
                if len(state) != len(param):
                    raise ValueError(
                        "{} receives the length of {}, "
                        "but the expected length is {}".format(
                            key, len(state), len(param)
                        )
                    )
                return param, state

            # ``state.shape`` may be a method or an attribute depending on the tensor type.
            state_shape = (
                state.shape() if inspect.ismethod(state.shape) else state.shape
            )
            if list(state_shape) != list(param.shape):
                raise ValueError(
                    "{} receives a shape {}, but the expected shape is {}.".format(
                        key, list(state_shape), list(param.shape)
                    )
                )
            return param, state

        matched_param_state = []
        # use_hook=False so state-dict hooks cannot alter the entries being matched.
        for key, param in self._state_dict_impl(use_hook=False).items():
            key_name = key if use_structured_name else param.name
            try:
                matched_param_state.append(_check_match(key_name, param))
            except ValueError as err:
                # Best-effort load: report and skip incompatible entries.
                warnings.warn("Skip loading for {}. ".format(key) + str(err))

        if _non_static_mode():
            for param, state in matched_param_state:
                param.set_value(state)
        else:

            def _set_var(var, ndarray):
                # Write ``ndarray`` into the scope tensor of ``var``, keeping the
                # tensor on the place it currently lives on.
                t = global_scope().find_var(var.name).get_tensor()
                p = t._place()
                if p.is_cpu_place():
                    place = core.CPUPlace()
                elif p.is_cuda_pinned_place():
                    place = core.CUDAPinnedPlace()
                elif p.is_xpu_place():
                    p = core.Place()
                    p.set_place(t._place())
                    place = core.XPUPlace(p.xpu_device_id())
                else:
                    p = core.Place()
                    p.set_place(t._place())
                    place = core.CUDAPlace(p.gpu_device_id())
                t.set(ndarray, place)

            try:
                executor = Executor(_get_device())._default_executor
                # restore parameter states
                core._create_loaded_parameter(
                    [param for param, state in matched_param_state],
                    global_scope(),
                    executor,
                )
                for param, state in matched_param_state:
                    _set_var(param, state)
            except ValueError as e:
                # Chain the original error so its cause is not lost.
                raise ValueError(
                    "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
                ) from e
1674

C
chentianyu03 已提交
1675 1676 1677 1678 1679
    def to(self, device=None, dtype=None, blocking=None):
        '''
        Cast the parameters and buffers of Layer by the given device, dtype and blocking.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device of the Layer which want to be stored.
            If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
            index of the GPUs or XPUs. Default: None.

            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.

        Returns:
            self

        Examples:
            .. code-block:: python

                # required: skip
                import paddle

                linear=paddle.nn.Linear(2, 2)
                linear.weight
                #Parameter containing:
                #Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(dtype='float64')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(device='cpu')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])
                linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
                #       [[-0.04989364, -0.56889004],
                #        [ 0.33960250,  0.96878713]])

        '''
        # The public API always walks every sublayer and casts every tensor,
        # not only the floating-point ones.
        cast_everything = dict(include_sublayers=True, floating_only=False)
        return self._to_impl(
            device=device, dtype=dtype, blocking=blocking, **cast_everything
        )
1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742

    def _apply(self, func, device, dtype, blocking, include_sublayers=True):
        """
        Apply ``func(t, device, dtype, blocking)`` to the parameters, their
        gradients and the buffers of this layer, optionally recursing into
        sublayers first.

        Only the results for buffers are stored back (into ``self._buffers``);
        results for parameters and gradients are discarded. That works because
        the ``transform`` used by ``_to_impl`` updates each tensor in place
        (via ``_share_data_with``) and returns it.

        Parameters:
            func(callable): Called as ``func(t, device, dtype, blocking)``.
            device, dtype, blocking: Forwarded unchanged to ``func``.
            include_sublayers(bool, optional): If True, recurse into
                ``self.children()`` before processing this layer. Default: True.
        """
        if include_sublayers:
            for layer in self.children():
                layer._apply(func, device, dtype, blocking, include_sublayers)

        for param in self._parameters.values():
            if param is None:
                continue
            with no_grad():
                func(param, device, dtype, blocking)

            if param.grad is not None:
                with no_grad():
                    func(param._grad_ivar(), device, dtype, blocking)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = func(buf, device, dtype, blocking)

        # dtype=None means "keep the original dtype"; don't overwrite the
        # layer's recorded dtype with None in that case.
        if dtype is not None:
            self._dtype = dtype

    def _transform(self, t, device, dtype, blocking):
        """
        Cast ``t`` to ``dtype`` and move it to ``device``, then make ``t`` share
        the resulting storage so the original tensor object is updated in place.

        ``device`` / ``dtype`` default to ``t.place`` / ``t.dtype`` when None.
        Returns ``t`` itself.
        """
        if device is None:
            device = t.place
        if dtype is None:
            dtype = t.dtype

        if type(dtype) is not VarDesc.VarType:
            dtype = convert_np_dtype_to_dtype_(dtype)

        # 1. gpu place need to determine whether the memory is sufficient for allocation:
        if t.place.is_gpu_place():
            # for gpu, minimum memory allocation unit is 256 bytes.
            size_dtype = core.size_of_dtype(dtype)
            # Note(zhangbo): Paddle GPU minimum memory allocation unit is 256 bytes, waiting_alloc_memory will compute 't' occupied memory space.
            # Coefficient 1.2 is used to avoid OOM that may occur in this critical state when the memory is just enough.
            # NOTE(review): with true division this equals (bytes + 256) * 1.2 rather than a
            # 256-byte round-up; kept as-is since it only serves as a conservative estimate.
            waiting_alloc_memory = (
                ((np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2
            )
            gpu_memory_available = core.gpu_memory_available()
            if gpu_memory_available < waiting_alloc_memory:
                # Not enough GPU memory: copy param / Tensor to cpu first
                t_used = t._copy_to(
                    paddle.CPUPlace(), blocking
                )  # k-v type will error
                # Release mem of t
                t.value().get_tensor()._clear()
            else:
                t_used = t
        else:
            t_used = t

        # 2. cast param / Tensor to dtype
        if dtype is not None and dtype != t_used.dtype:
            with paddle.fluid.framework._dygraph_place_guard(
                place=t_used.place
            ):
                t_casted = t_used.cast(dtype=dtype)
        else:
            t_casted = t_used

        # 3. Copy casted cpu param / Tensor to device
        if device is not None and not t_casted.place._equals(device):
            new_t = t_casted._copy_to(device, blocking)
        else:
            new_t = t_casted

        # 4. share Tensor to origin param / Tensor
        dst_tensor = t.value().get_tensor()
        src_tensor = new_t.value().get_tensor()
        dst_tensor._share_data_with(src_tensor)

        return t

    def _to_impl(
        self,
        device=None,
        dtype=None,
        blocking=None,
        include_sublayers=True,
        floating_only=False,
    ):
        '''
        Cast the parameters and buffers of Layer by the given device, dtype and blocking.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device of the Layer which want to be stored.
            If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
            index of the GPUs or XPUs. Default: None.

            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.

            include_sublayers(bool|True, optional): If True, deal with self and all sublayers parameters and buffers, if not only deal with self parameters and buffers. Default: True.

            floating_only(bool|False, optional): If True, only cast all floating point parameters and buffers of Layer by the given device, dtype and blocking. Default: False.

        Returns:
            self

        Raises:
            ValueError: if ``device`` is not a str or one of the supported place types.

        '''

        if device is None and dtype is None and blocking is None:
            return self

        if device is not None:
            if isinstance(device, str):
                device = paddle.device._convert_to_place(device)
            elif isinstance(
                device,
                (
                    core.CPUPlace,
                    core.CUDAPlace,
                    core.CUDAPinnedPlace,
                    core.XPUPlace,
                ),
            ):
                pass
            else:
                raise ValueError(
                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
                    + type(device).__name__
                )

        if blocking is None:
            blocking = True
        else:
            assert isinstance(
                blocking, bool
            ), "blocking value error, must be the True, False or None"

        def transform(t, device, dtype, blocking):
            # Non-floating tensors are left untouched when floating_only is set.
            if floating_only and (not paddle.is_floating_point(t)):
                return t
            return self._transform(t, device, dtype, blocking)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            # _apply also records ``dtype`` on self (and sublayers) when it is not None.
            self._apply(transform, device, dtype, blocking, include_sublayers)

        return self
C
chentianyu03 已提交
1876

1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888
    def _startup_program(self):
        """
        Build and return a startup program that holds the initialization
        operations of every parameter yielded by ``self.parameters()``.

        NOTE(dev): This is a very low level API and only for inner developer.
        """
        program = Program()
        for parameter in self.parameters():
            # Ask each parameter to create its init op in the program's global block.
            parameter._create_init_op(program.global_block())
        return program

1889 1890 1891
    # [aliases] Compatible with old method names: both old names refer to set_state_dict.
    set_dict = load_dict = set_state_dict