layers.py 77.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
Xin Pan 已提交
15
import collections
16
import copy
17
import inspect
18 19 20 21 22
import re
import warnings
import weakref

import numpy as np
23

24
import paddle
25
from paddle import profiler
26 27 28 29
from paddle.fluid import core, framework, unique_name
from paddle.fluid.core import VarDesc
from paddle.fluid.dygraph import no_grad
from paddle.fluid.dygraph.base import (
30
    _convert_into_variable,
31 32
    in_declarative_mode,
    program_desc_tracing_guard,
33
)
34
from paddle.fluid.dygraph_utils import _append_activation_in_dygraph
35
from paddle.fluid.executor import Executor, global_scope
36 37
from paddle.fluid.framework import Parameter, Program
from paddle.fluid.framework import _current_expected_place as _get_device
38
from paddle.fluid.framework import (
39
    _global_flags,
40
    convert_np_dtype_to_dtype_,
41
    default_main_program,
42 43
    in_dygraph_mode,
)
44 45 46
from paddle.fluid.layer_helper_base import LayerHelperBase
from paddle.fluid.param_attr import ParamAttr
from paddle.profiler.utils import in_profiler_mode
47
from paddle.utils import deprecated
48

49
__all__ = []

# Regexes used by ``_convert_camel_to_snake``: the first inserts an underscore
# before every capitalised word ("MyLayer" -> "My_Layer"), the second between
# a lower-case letter and a following upper-case one ("fooBar" -> "foo_Bar").
_first_cap_re = re.compile(r'(.)([A-Z][a-z]+)')
_all_cap_re = re.compile(r'([a-z])([A-Z])')


def record_program_ops_pre_hook(layer, inputs):
    """
    Forward pre-hook that marks where a layer's operators begin (static graph only).

    In static-graph mode this stores the current number of ops in the main
    program's current block into ``layer._op_recorder.start`` and flags the
    recording as valid. If ``start`` was already set (non-negative), the
    recording is flagged invalid and a warning is emitted instead. In dygraph
    mode nothing happens.

    Args:
        layer: the layer about to run ``forward``.
        inputs: the forward inputs (unused).

    Returns:
        None, so the forward inputs are left unchanged.
    """
    if in_dygraph_mode():
        return None

    recorder = layer._op_recorder
    if recorder.start >= 0:
        # A start index already exists: refuse to overwrite it and warn.
        recorder.is_valid = False
        warnings.warn(
            "{} has recorded the op information before. Please check whether you call this layer twice.".format(
                layer._full_name
            )
        )
        return None

    recorder.start = len(default_main_program().current_block().ops)
    recorder.is_valid = True
    return None


def set_op_customized_attrs_post_hook(layer, inputs, outputs):
    """
    A post-hook to append customized attributes into all operators generated in current layer.

    Only acts in static-graph mode, and only when the recording started by
    ``record_program_ops_pre_hook`` is valid. Every op appended to the current
    block since ``layer._op_recorder.start`` receives each entry of
    ``layer._customized_attrs``. Afterwards the helpers stored in
    ``layer._op_recorder.hooks`` (the pre-hook and this post-hook) are
    removed.

    Args:
        layer: the layer whose ``forward`` just finished.
        inputs: the forward inputs (unused).
        outputs: the forward outputs (unused).

    Returns:
        None, so the layer's outputs are left unchanged.
    """
    if in_dygraph_mode() or not layer._op_recorder.is_valid:
        return None

    recorder = layer._op_recorder
    block_ops = default_main_program().current_block().ops
    start = recorder.start
    end = len(block_ops)
    assert start >= 0 and end >= start
    ops = block_ops[start:end]

    recorder.end = end
    recorder.ops = ops

    for op in ops:
        for attr_name, val in layer._customized_attrs.items():
            op._set_attr(attr_name, val)

    # The attributes only need to be stamped once: remove pre-hook and post-hook.
    for hook_helper in recorder.hooks:
        hook_helper.remove()

    return None


def _scope_dist2single(dist_scope):
    """
    Map a model-parallel layer scope name to its single-card counterpart.

    ``Layer.__init__`` applies this to the default (snake-cased class) name
    scope, so e.g. a ``RowParallelLinear`` layer gets the ``linear`` name
    prefix. Names without a mapping are returned unchanged.

    Args:
        dist_scope (str): snake-cased scope name.

    Returns:
        str, the mapped scope name, or ``dist_scope`` itself.
    """
    mapping = {
        "row_parallel_linear": "linear",
        "column_parallel_linear": "linear",
        "vocab_parallel_embedding": "embedding",
        # "parallel_cross_entropy": "cross_entropy" is intentionally absent: although
        # mp_layer has parallel_cross_entropy, it has no parameters, so no mapping is needed.
    }
    return mapping.get(dist_scope, dist_scope)


def _convert_camel_to_snake(name):
    """
    Convert a CamelCase name to snake_case, e.g. ``MyLayer`` -> ``my_layer``.

    Used by ``Layer.__init__`` to derive the default name scope from the
    class name.

    Args:
        name (str): CamelCase identifier.

    Returns:
        str, the lower-cased, underscore-separated name.
    """
    # First split before each capitalised word, then at lower->upper
    # boundaries that the first pass left (e.g. acronym tails).
    s1 = _first_cap_re.sub(r'\1_\2', name)
    return _all_cap_re.sub(r'\1_\2', s1).lower()


def _addindent(string, indent):
    """
    Indent every line of ``string`` except the first by ``indent`` spaces.

    A string without newlines is returned unchanged. Empty continuation
    lines are padded as well.

    Args:
        string (str): possibly multi-line text.
        indent (int): number of spaces to prepend to lines after the first.

    Returns:
        str, the re-indented text.
    """
    first, *rest = string.split('\n')
    if not rest:
        return string
    pad = ' ' * indent
    return first + '\n' + '\n'.join(pad + line for line in rest)


class LayerObjectHelper(LayerHelperBase):
    """
    ``LayerHelperBase`` bound to a single Layer, used to build that layer's ops.

    ``name`` is used both as the helper name and as its ``layer_type``.
    """

    def __init__(self, name):
        super().__init__(name, layer_type=name)

    def append_op(
        self,
        type=None,
        inputs=None,
        outputs=None,
        attrs=None,
        stop_gradient=None,
    ):
        """Append an operator to the current block of the main program.

        Args:
            type: operator type.
            inputs: input variables of the operator.
            outputs: output variables of the operator.
            attrs: operator attributes.
            stop_gradient: forwarded unchanged to the block's ``append_op``.

        Returns:
            Whatever ``current_block().append_op`` returns (presumably the
            newly created operator).
        """
        return self.main_program.current_block().append_op(
            type=type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=stop_gradient,
        )

    def _multiple_input(self, inputs_in):
        """Return ``inputs_in`` as a list, converting each item with ``to_variable``.

        A list/tuple is converted element-wise; any other value becomes a
        one-element list.
        """
        if isinstance(inputs_in, (list, tuple)):
            return [self.to_variable(inp) for inp in inputs_in]
        return [self.to_variable(inputs_in)]

    # TODO: make it public when we need it
    def _input(self, inputs_in):
        """Return the single input contained in ``inputs_in``.

        Raises:
            ValueError: if ``inputs_in`` does not hold exactly one input.
        """
        inputs = self._multiple_input(inputs_in)
        if len(inputs) != 1:
            # Must raise an exception instance: raising a plain str is itself
            # a TypeError in Python 3 and hides the real message.
            raise ValueError(
                "{0} layer only takes one input".format(self.layer_type)
            )
        return inputs[0]

    def _multiple_param_attr(self, length, param_attr_in=None):
        """Expand ``param_attr_in`` to ``length`` parameter attributes.

        A single ParamAttr (bare or in a one-element list) is deep-copied
        ``length`` times so every parameter gets its own attribute object.
        A list whose length already equals ``length`` is returned as is.

        Raises:
            ValueError: if the number of attributes is neither 1 nor ``length``.
        """
        param_attr = param_attr_in
        if isinstance(param_attr, ParamAttr):
            param_attr = [param_attr]

        if len(param_attr) != 1 and len(param_attr) != length:
            raise ValueError(
                "parameter number mismatch in {}".format(self.name)
            )
        if len(param_attr) == 1 and length != 1:
            param_attr = [copy.deepcopy(param_attr[0]) for _ in range(length)]
        return param_attr

    def iter_inputs_and_params(self, inputs_in, param_attr_in=None):
        """Access all inputs and params one by one.

        Args:
            inputs_in: a single input or a list/tuple of inputs; None means
                no inputs.
            param_attr_in: parameter attribute(s), normalised with
                ``ParamAttr._to_attr``.

        Yields:
            (input, param_attr) pairs.

        Raises:
            ValueError: if ``param_attr_in`` normalises to a bool. Because
                this is a generator, the error surfaces on first iteration.
        """
        param_attr_in = ParamAttr._to_attr(param_attr_in)
        if isinstance(param_attr_in, bool):
            raise ValueError(
                'Param_attr should not be False in {}'.format(self.name)
            )
        inputs = inputs_in if (inputs_in is not None) else []
        inputs = self._multiple_input(inputs)
        param_attrs = self._multiple_param_attr(len(inputs), param_attr_in)
        yield from zip(inputs, param_attrs)

    def input_dtype(self, inputs_in):
        """Get the data type shared by all inputs.

        Args:
            inputs_in: a single input or a list/tuple of inputs; None means
                no inputs.

        Returns:
            The common dtype, or None when there are no inputs.

        Raises:
            ValueError: if the inputs do not all have the same dtype.
        """
        inputs_in = inputs_in if (inputs_in is not None) else []
        inputs = self._multiple_input(inputs_in)
        dtype = None
        for each in inputs:
            if dtype is None:
                dtype = each.dtype
            elif dtype != each.dtype:
                # %s rather than %d: dtype objects are not guaranteed to be ints.
                raise ValueError(
                    "Data Type mismatch: %s to %s in %s"
                    % (dtype, each.dtype, self.name)
                )
        return dtype

    def get_parameter(self, name):
        """Get a parameter of the main program's global block by name.

        Args:
            name: parameter's name.

        Returns:
            The target Parameter.

        Raises:
            ValueError: if the variable found under ``name`` is not a Parameter.
        """
        param = self.main_program.global_block().var(name)
        if not isinstance(param, Parameter):
            raise ValueError(
                "no Parameter name %s found in %s" % (name, self.name)
            )
        return param

    # TODO: this should not be called anymore after all activation func move to Layers
    def append_activation(self, input_var, act=None, use_cudnn=None):
        """Append an activation op after ``input_var``.

        Args:
            input_var: the input variable. The len(input_var.shape) is
                larger or equal than 2.
            act: activation type name (str); None means no activation.
            use_cudnn: if truthy, set the ``use_cudnn`` attribute.

        Returns:
            ``input_var`` itself when ``act`` is None, otherwise the
            activated variable.

        Raises:
            TypeError: if ``act`` is neither None nor a str.
        """
        if act is None:
            return input_var
        if not isinstance(act, str):
            raise TypeError(
                "{} should be unicode or str in {}".format(act, self.name)
            )
        act = {'type': act}

        if use_cudnn:
            act['use_cudnn'] = use_cudnn
        use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]
        if use_mkldnn:
            act['use_mkldnn'] = use_mkldnn
        act_type = act.pop('type')
        if in_dygraph_mode():
            res = _append_activation_in_dygraph(
                input_var, act_type, use_cudnn, use_mkldnn
            )
            return res
        else:
            tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
            self.append_op(
                type=act_type,
                inputs={"X": [input_var]},
                outputs={"Out": [tmp]},
                attrs=act,
            )
            return tmp

    def is_instance(self, param, cls):
        """Check that ``param`` is an instance of ``cls``.

        Args:
            param: parameter to be checked.
            cls: expected class of the parameter.

        Returns:
            None when the check passes.

        Raises:
            TypeError: if ``param`` is not an instance of ``cls``.
        """
        if not isinstance(param, cls):
            raise TypeError(
                "The input {0} parameter of method {1} must be {2}, in layer {3}".format(
                    param, self.layer_type, cls.__name__, self.name
                )
            )


class LayerOpsRecoder:
    """
    Record generated operators information in nn.Layer.

    Attributes:
        start (int): index in the current block's op list where recording
            started; a negative value (default -1) means not yet recorded.
        end (int): slice end index of the recorded ops; -1 until set.
        ops (list|None): the recorded operators.
        is_valid (bool): whether the recorded range may be used.
        hooks (list|None): hook helpers (objects with ``remove()``) that
            drive the recording and are removed once it is done.
    """

    def __init__(self, start=-1, end=-1, ops=None, is_valid=False, hooks=None):
        self.start = start
        self.end = end
        self.ops = ops
        self.is_valid = is_valid
        self.hooks = hooks


class HookRemoveHelper:
    """A HookRemoveHelper that can be used to remove hook.

    Each helper takes a fresh id from a class-level counter; the caller is
    expected to store its hook in ``hooks`` under ``helper._hook_id``. The
    container is held through a weak reference, so the helper does not keep
    it alive, and ``remove`` is a no-op once the hook is gone or the
    container has been garbage-collected.
    """

    next_hook_id = 0

    def __init__(self, hooks):
        # ``hooks`` must support weak references (e.g. OrderedDict; a plain
        # dict does not).
        self._hooks_ref = weakref.ref(hooks)
        self._hook_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id += 1

    def remove(self):
        """Remove the associated hook from its container, if still present."""
        hooks = self._hooks_ref()
        if hooks is not None and self._hook_id in hooks:
            del hooks[self._hook_id]


class Layer:
    """
    Dynamic graph Layer based on object-oriented design (OOD). It holds the parameters of the layer, the structure of the forward graph and so on.

    Parameters:
        name_scope (str, optional): prefix name used by the layer to name parameters.
            If prefix is "my_layer", parameter name in MyLayer
            can be "my_layer_0.w_n", where "w" is the parameter
            base name and "n" is a unique suffix auto-generated.
            If None, prefix name will be snake cased class name. Default: None.
        dtype(str, optional): data type of this layer.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                Default: "float32"

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            class MyLayer(paddle.nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = paddle.nn.Linear(1, 1)
                    self._dropout = paddle.nn.Dropout(p=0.5)
                def forward(self, input):
                    temp = self._linear(input)
                    temp = self._dropout(temp)
                    return temp
            x = paddle.randn([10, 1], 'float32')
            mylayer = MyLayer()
            mylayer.eval()  # set mylayer._dropout to eval mode
            out = mylayer(x)
            mylayer.train()  # set mylayer._dropout to train mode
            out = mylayer(x)
    """

    def __init__(self, name_scope=None, dtype="float32"):
        self.training = True
        if name_scope is None:
            # Default scope: the snake-cased class name, with model-parallel
            # layer names mapped to their single-card counterparts.
            name_scope = _convert_camel_to_snake(self.__class__.__name__)
            name_scope = _scope_dist2single(name_scope)
        self._full_name = unique_name.generate(name_scope)
        self._helper = LayerObjectHelper(self._full_name)
        self._built = False
        self._dtype = dtype
        self._init_in_dynamic_mode = in_dygraph_mode()

        self._parameters = collections.OrderedDict()
        # Buffers the variable (not parameter) created in layer
        self._buffers = collections.OrderedDict()
        self._non_persistable_buffer_names_set = set()
        self._sub_layers = collections.OrderedDict()
        self._loaddict_holder = collections.OrderedDict()

        # Record generated op_descs in this layer
        self._op_recorder = LayerOpsRecoder(ops=[], hooks=[])
        self._customized_attrs = {}

        # OrderedDict (unlike a plain dict) supports weak references, which
        # HookRemoveHelper relies on when removing hooks from these maps.
        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

        self._casted_by_pure_fp16 = False

        self._state_dict_hooks = collections.OrderedDict()
        # Records original functions after @to_static to support rollback
        self._original_funcs = collections.OrderedDict()

    def train(self):
        """

        Sets this Layer and all its sublayers to training mode.
        This only affects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                mylayer.train()  # set mylayer._dropout to train mode
                out = mylayer(x)

        """
        # global setting in dygraph
        # NOTE(chenweihang): nn.Layer also can be used in static mode,
        # but _dygraph_tracer() can not be called in static mode
        if in_dygraph_mode():
            framework._dygraph_tracer().train_mode()
        # Layer-level setting
        self.training = True
        for layer in self.sublayers():
            layer.training = True

    def eval(self):
        """
        Sets this Layer and all its sublayers to evaluation mode.
        This only affects certain modules like `Dropout` and `BatchNorm`.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                print(out)

        """
        # global setting in dygraph
        # NOTE(chenweihang): nn.Layer also can be used in static mode,
        # but _dygraph_tracer() can not be called in static mode
        if in_dygraph_mode():
            framework._dygraph_tracer().eval_mode()
        # Layer-level setting
        self.training = False
        for layer in self.sublayers():
            layer.training = False

    def apply(self, fn):
        """
U
ustiniankw 已提交
495

L
LielinJiang 已提交
496 497 498 499 500 501 502
        Applies ``fn`` recursively to every sublayer (as returned by ``.sublayers()``)
        as well as self. Typical use includes initializing the parameters of a model.

        Parameters:
            fn (function): a function to be applied to each sublayer

        Returns:
U
ustiniankw 已提交
503
            Layer, self
L
LielinJiang 已提交
504 505 506 507 508 509

        Example::
            .. code-block:: python

              import paddle
              import paddle.nn as nn
510

L
LielinJiang 已提交
511 512 513 514 515
              net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

              def init_weights(layer):
                  if type(layer) == nn.Linear:
                      print('before init weight:', layer.weight.numpy())
516
                      new_weight = paddle.full(shape=layer.weight.shape, dtype=layer.weight.dtype, fill_value=0.9)
L
LielinJiang 已提交
517 518 519 520 521 522
                      layer.weight.set_value(new_weight)
                      print('after init weight:', layer.weight.numpy())

              net.apply(init_weights)

              print(net.state_dict())
U
ustiniankw 已提交
523

L
LielinJiang 已提交
524
        """
525
        for layer in self.children():
L
LielinJiang 已提交
526 527 528 529 530 531
            layer.apply(fn)

        fn(self)

        return self

X
Xin Pan 已提交
532
    def full_name(self):
U
ustiniankw 已提交
533 534 535
        """

        Full name for this layer, composed by name_scope + "/" + MyLayer.__class__.__name__
X
Xin Pan 已提交
536

537
        Returns:
U
ustiniankw 已提交
538
            str, full name of this layer.
539 540 541 542 543 544 545 546

        Example::
            .. code-block:: python

                import paddle

                class LinearNet(paddle.nn.Layer):
                    def __init__(self):
547
                        super().__init__(name_scope = "demo_linear_net")
548 549 550 551 552 553 554 555
                        self._linear = paddle.nn.Linear(1, 1)

                    def forward(self, x):
                        return self._linear(x)

                linear_net = LinearNet()
                print(linear_net.full_name())   # demo_linear_net_0

X
Xin Pan 已提交
556 557 558
        """
        return self._full_name

559
    def register_forward_post_hook(self, hook):
        """

        Register a forward post-hook for Layer. The hook will be called after `forward` function has been computed.

        It should have the following form, `input` and `output` of the `hook` is `input` and `output` of the `Layer` respectively.
        User can use forward post-hook to change the output of the Layer or perform information statistics tasks on the Layer.

        hook(Layer, input, output) -> None or modified output

        Parameters:
            hook(function): a function registered as a forward post-hook

        Returns:
            HookRemoveHelper, a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_post_hook changes the output of the layer: output = output * 2
                def forward_post_hook(layer, input, output):
                    # user can use layer, input and output for information statistics tasks

                    # change the output
                    return output * 2

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

                value1 = np.arange(26).reshape(2, 13).astype("float32")
                in1 = paddle.to_tensor(value1)

                out0 = linear(in1)

                # remove the hook
                forward_post_hook_handle.remove()

                out1 = linear(in1)

                # the hook changed the linear's output to output * 2, so out0 is equal to out1 * 2.
                assert (out0.numpy() == (out1.numpy()) * 2).all()

        """
        # The helper's id is the key, so helper.remove() deletes exactly this hook.
        hook_remove_helper = HookRemoveHelper(self._forward_post_hooks)
        self._forward_post_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def register_forward_pre_hook(self, hook):
        """

        Register a forward pre-hook for Layer. The hook will be called before `forward` function has been computed.

        It should have the following form, `input` of the `hook` is `input` of the `Layer`,
        hook can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if
        a single value is returned(unless that value is already a tuple).
        User can use forward pre-hook to change the input of the Layer or perform information statistics tasks on the Layer.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper, a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()` .

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_pre_hook changes the input of the layer: input = input * 2
                def forward_pre_hook(layer, input):
                    # user can use layer and input for information statistics tasks

                    # change the input
                    input_return = (input[0] * 2)
                    return input_return

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                value0 = np.arange(26).reshape(2, 13).astype("float32")
                in0 = paddle.to_tensor(value0)
                out0 = linear(in0)

                # remove the hook
                forward_pre_hook_handle.remove()

                value1 = value0 * 2
                in1 = paddle.to_tensor(value1)
                out1 = linear(in1)

                # the hook changed the linear's input to input * 2, so out0 is equal to out1.
                assert (out0.numpy() == out1.numpy()).all()
        """
        # The helper's id is the key, so helper.remove() deletes exactly this hook.
        hook_remove_helper = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[hook_remove_helper._hook_id] = hook
        return hook_remove_helper

    def create_parameter(
        self,
        shape,
        attr=None,
        dtype=None,
        is_bias=False,
        default_initializer=None,
    ):
674
        """Create parameters for this layer.
675

676
        Parameters:
677
            shape(list): Shape of the parameter.
678 679
            attr(ParamAttr, optional): Parameter attribute of weight. Please refer to :ref:`api_paddle_ParamAttr`. Default: None.
            dtype(str, optional): Data type of this parameter.
680
                If set str, it can be "bool",  "float16", "float32", "float64",
681 682
                "int8", "int16", "int32", "int64", "uint8" or "uint16". Default: "float32".
            is_bias(bool, optional): if this is a bias parameter. Default: False.
683
            default_initializer(Initializer, optional): the default initializer for this parameter.
684
                If set None, default initializer will be set to paddle.nn.initializer.Xavier and paddle.nn.initializer.Constant
685
                for non-bias and bias parameter, respectively. Default: None.
686

687
        Returns:
688 689 690 691 692 693 694 695 696
            :Tensor, created parameter.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
697
                        super().__init__()
698 699 700 701 702 703 704 705 706 707 708
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

709
        """
H
hong 已提交
710
        temp_attr = copy.deepcopy(attr)
711
        if isinstance(temp_attr, str) and temp_attr == "":
H
hong 已提交
712
            temp_attr = None
713 714 715 716 717 718 719 720 721
        return self._helper.create_parameter(
            temp_attr, shape, dtype, is_bias, default_initializer
        )

    @deprecated(
        since="2.0.0",
        update_to="paddle.nn.Layer.create_tensor",
        reason="New api in create_tensor, easier to use.",
    )
    def create_variable(self, name=None, persistable=None, dtype=None):
        """

        Create Tensor for this layer.

        Deprecated since 2.0.0: use ``create_tensor`` instead. This method
        simply forwards its arguments to ``create_tensor``.

        Parameters:
            name(str, optional): name of the tensor. Please refer to :ref:`api_guide_Name` . Default: None

            persistable(bool, optional): if set this tensor persistable. Default: False

            dtype(str, optional): data type of this tensor. If set str, it can be "bool", "float16", "float32", "float64","int8", "int16", "int32", "int64", "uint8" or "uint16". If set None, it will be "float32". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super().__init__()
                        self.linear = paddle.nn.Linear(in_features, out_features)

                        self.back_var = self.create_variable(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        # Same semantics as create_tensor; delegate rather than keep a
        # duplicated copy of the naming and variable-creation logic.
        return self.create_tensor(
            name=name, persistable=persistable, dtype=dtype
        )

    # TODO: Add more parameter list when we need them
    def create_tensor(self, name=None, persistable=None, dtype=None):
        """

        Create Tensor for this layer.

        The tensor is created as a LOD_TENSOR variable in the current block of
        this layer's main program. A given ``name`` is prefixed with the
        layer's full name (``<full_name>.<name>``); without a name, a unique
        ``<full_name>._generated_var`` based name is generated.

        Parameters:
            name(str, optional): name of the tensor. Please refer to :ref:`api_guide_Name` . Default: None
            persistable(bool, optional): if set this tensor persistable. Default: False
            dtype(str, optional): data type of this tensor.
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
                If set None, it will be "float32". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super().__init__()
                        self.linear = paddle.nn.Linear(in_features, out_features)

                        self.back_var = self.create_tensor(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        if name is not None:
            # Namespace user-given names under this layer's full name.
            var_name = ".".join([self._full_name, name])
        else:
            var_name = unique_name.generate(
                ".".join([self._full_name, "_generated_var"])
            )

        return self._helper.main_program.current_block().create_var(
            name=var_name,
            persistable=persistable,
            dtype=dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
        )

    def parameters(self, include_sublayers=True):
        """

        Returns a list of all Parameters from current layer and its sub-layers.

        Parameters:
            include_sublayers(bool, optional): Whether include the parameters of sublayers.
                If True, also include the parameters from sublayers. Default: True

        Returns:
            list of Tensor, a list of Parameters (de-duplicated, in the order
            yielded by ``named_parameters``).

        Examples:
            .. code-block:: python

                import paddle

                linear = paddle.nn.Linear(1,1)
                print(linear.parameters())  # print linear_0.w_0 and linear_0.b_0

        """
        collected = []
        for _param_name, param in self.named_parameters(
            include_sublayers=include_sublayers
        ):
            collected.append(param)
        return collected
X
Xin Pan 已提交
848

849
    def children(self):
        """

        Returns an iterator over immediate children layers.

        Only direct sub-layers are produced (no recursion). Entries are taken
        from ``named_children``, so ``None`` placeholders and duplicate layer
        objects are skipped.

        Yields:
            Layer: a child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)

                layer_list = list(model.children())

                print(layer_list)   # [<paddle.nn.layer.common.Linear object at 0x7f7b8113f830>, <paddle.nn.layer.common.Linear object at 0x7f7b8113f950>]

        """
        yield from (child for _child_name, child in self.named_children())

    def named_children(self):
        """Returns an iterator over immediate children layers, yielding both
        the name of the layer as well as the layer itself.

        ``None`` entries in ``self._sub_layers`` are skipped, and a layer object
        registered under several names is yielded only once (first name wins).

        Yields:
            (string, Layer): Tuple containing a name and child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)
                for prefix, layer in model.named_children():
                    print(prefix, layer)
                    # ('0', <paddle.nn.layer.common.Linear object at 0x7fb61ed85830>)
                    # ('1', <paddle.nn.layer.common.Linear object at 0x7fb61ed85950>)

        """
        already_seen = set()
        for child_name, child in self._sub_layers.items():
            if child is None or child in already_seen:
                continue
            already_seen.add(child)
            yield child_name, child

J
Jiabin Yang 已提交
901
    def sublayers(self, include_self=False):
        """

        Returns a list of sub layers, collected recursively via ``named_sublayers``.

        Parameters:
            include_self(bool, optional): Whether return self as sublayers. Default: False

        Returns:
            list of Layer, a list of sub layers.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                mylayer = MyLayer()
                print(mylayer.sublayers())  # [<paddle.nn.layer.common.Linear object at 0x7f44b58977d0>, <paddle.nn.layer.common.Dropout object at 0x7f44b58978f0>]

        """
        named = self.named_sublayers(include_self=include_self)
        return list(layer for _prefix, layer in named)

938 939 940 941 942 943 944 945 946 947 948 949 950 951 952
    def named_parameters(self, prefix='', include_sublayers=True):
        """
        Returns an iterator over all parameters in the Layer, yielding tuple of name and parameter.

        Names are dotted paths (``"<sublayer prefix>.<key>"``). ``None`` entries
        are skipped, and a Parameter object shared by several layers is yielded
        only once, under the first name encountered.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_sublayers(bool, optional): Whether include the parameters of sublayers.
                If True, also include the named parameters from sublayers. Default: True.

        Yields:
            (string, Parameter): Tuple of name and Parameter

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for name, param in model.named_parameters():
                    print(name, param)

        """
        emitted = set()
        if include_sublayers:
            owners = self.named_sublayers(prefix=prefix, include_self=True)
        else:
            owners = [(prefix, self)]

        for owner_prefix, owner in owners:
            separator = '.' if owner_prefix else ''
            for param_key, param in owner._parameters.items():
                if param is None or param in emitted:
                    continue
                emitted.add(param)
                yield owner_prefix + separator + param_key, param

J
Jiabin Yang 已提交
977
    def named_sublayers(self, prefix='', include_self=False, layers_set=None):
        """
        Returns an iterator over all sublayers in the Layer, yielding tuple of name and sublayer.
        The duplicate sublayer will only be yielded once.

        The traversal is a depth-first pre-order walk over ``_sub_layers``;
        ``None`` entries are skipped and ``layers_set`` is shared across the
        recursive calls so that a layer reachable by several paths is reported
        only under the first one.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_self(bool, optional): Whether include the Layer itself. Default: False.
            layers_set(set, optional): The set to record duplicate sublayers. Default: None.

        Yields:
            (string, Layer): Tuple of name and Layer

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        if layers_set is None:
            layers_set = set()
        if include_self and self not in layers_set:
            layers_set.add(self)
            yield prefix, self

        separator = '.' if prefix else ''
        for child_key, child in self._sub_layers.items():
            if child is None:
                continue
            # Children always report themselves; de-duplication happens via
            # the shared ``layers_set``.
            yield from child.named_sublayers(
                prefix=prefix + separator + child_key,
                include_self=True,
                layers_set=layers_set,
            )
1015

1016
    def register_buffer(self, name, tensor, persistable=True):
        """
        Registers a tensor as buffer into the layer.

        `buffer` is a non-trainable tensor and will not be updated by optimizer,
        but is necessary for evaluation and inference. For example, the mean and variance in BatchNorm layers.
        The registered buffer is persistable by default, and will be saved into
        `state_dict` alongside parameters. If set persistable=False, it registers
        a non-persistable buffer, so that it will not be a part of `state_dict` .

        Buffers can be accessed as attributes using given names.

        Parameters:
            name (string): name of the buffer. The buffer can be accessed
                from this layer using the given name. It must be a non-empty
                string without `.` and must not clash with an existing
                non-buffer attribute.
            tensor (Tensor): the tensor to be registered as buffer, or None.
            persistable (bool): whether the buffer is part of this layer's
                state_dict. Default: True

        Returns:
            None

        Raises:
            ValueError: if ``super().__init__()`` has not been called yet.
            TypeError: if ``name`` is not a str, or ``tensor`` is neither None
                nor a Tensor.
            KeyError: if ``name`` is empty, contains `.`, or is already used by
                a non-buffer attribute.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                # get the buffer by attribute.
                print(linear.buf_name)

        """
        if '_buffers' not in self.__dict__:
            raise ValueError("super().__init__() should be called first")
        if not isinstance(name, str):
            raise TypeError(
                "The name of buffer should be a string, but received {}.".format(
                    type(name).__name__
                )
            )
        if '.' in name:
            raise KeyError(
                "The name of buffer can not contain `.`, "
                "because when you access the newly added buffer in the "
                "form of `self.**.**`, it will cause AttributeError."
            )
        if name == '':
            raise KeyError("The name of buffer can not be empty.")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists.".format(name))
        if tensor is not None:
            # Exact type match on purpose: subclasses are not accepted.
            tensor_type = type(tensor)
            if not (
                tensor_type == core.VarBase
                or tensor_type == core.eager.Tensor
            ):
                raise TypeError(
                    "The registered buffer should be a Paddle.Tensor, but received {}.".format(
                        type(tensor).__name__
                    )
                )

        self._buffers[name] = tensor
        non_persistable_names = self._non_persistable_buffer_names_set
        if persistable:
            non_persistable_names.discard(name)
        else:
            non_persistable_names.add(name)

    def buffers(self, include_sublayers=True):
        """

        Returns a list of all buffers from current layer and its sub-layers.

        Parameters:
            include_sublayers(bool, optional): Whether include the buffers of sublayers. If True, also include the buffers from sublayers. Default: True

        Returns:
            list of Tensor, a list of buffers (in the order yielded by
            ``named_buffers``).

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                print(linear.buffers())     # == print([linear.buf_name])

        """
        collected = []
        for _buffer_name, buf in self.named_buffers(
            include_sublayers=include_sublayers
        ):
            collected.append(buf)
        return collected

    def named_buffers(self, prefix='', include_sublayers=True):
        """
        Returns an iterator over all buffers in the Layer, yielding tuple of name and Tensor.

        Names are dotted paths (``"<sublayer prefix>.<key>"``). ``None`` entries
        are skipped, and a buffer object shared by several layers is yielded
        only once, under the first name encountered.

        Parameters:
            prefix(str, optional): Prefix to prepend to all buffer names. Default: ''.
            include_sublayers(bool, optional): Whether include the buffers of sublayers.
                If True, also include the named buffers from sublayers. Default: True.

        Yields:
            (string, Tensor): Tuple of name and tensor

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                buffer1 = paddle.to_tensor(np.array([0]).astype("float32"))
                # register a tensor as buffer by specific `persistable`
                fc1.register_buffer("buf_name_1", buffer1, persistable=True)

                fc2 = paddle.nn.Linear(3, 10)
                buffer2 = paddle.to_tensor(np.array([1]).astype("float32"))
                # register a buffer by assigning an attribute with Tensor.
                # The `persistable` can only be False by this way.
                fc2.buf_name_2 = buffer2

                model = paddle.nn.Sequential(fc1, fc2)

                # get all named buffers
                for name, buffer in model.named_buffers():
                    print(name, buffer)

        """
        emitted = set()
        if include_sublayers:
            owners = self.named_sublayers(prefix=prefix, include_self=True)
        else:
            owners = [(prefix, self)]

        for owner_prefix, owner in owners:
            separator = '.' if owner_prefix else ''
            for buffer_key, buf in owner._buffers.items():
                if buf is None or buf in emitted:
                    continue
                emitted.add(buf)
                yield owner_prefix + separator + buffer_key, buf

X
Xin Pan 已提交
1171
    def clear_gradients(self):
        """
        Clear the gradients of all trainable parameters for this layer.

        Parameters whose ``trainable`` flag is False are left untouched.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)
                linear = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01,
                                            parameters=linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                linear.clear_gradients()

        """
        trainable_params = (param for param in self.parameters() if param.trainable)
        for param in trainable_params:
            param.clear_gradient()
X
Xin Pan 已提交
1198

1199
    def _build_once(self, *args, **kwargs):
1200 1201
        pass

1202
    def _dygraph_call_func(self, *inputs, **kwargs):
        """
        Full call path: run forward pre-hooks, one-time build, ``forward`` and
        forward post-hooks, in that order.

        A pre-hook returning non-None replaces ``inputs`` (a non-tuple result
        is wrapped into a 1-tuple). A post-hook returning non-None replaces
        ``outputs``. When profiling is enabled, ``forward`` runs inside a
        ``profiler.RecordEvent`` named after the layer class.

        Returns:
            The (possibly hook-replaced) output of ``self.forward``.
        """
        from paddle.distributed import parallel_helper

        for pre_hook in self._forward_pre_hooks.values():
            replacement = pre_hook(self, inputs)
            if replacement is None:
                continue
            inputs = (
                replacement
                if isinstance(replacement, tuple)
                else (replacement,)
            )

        if not self._built:
            with program_desc_tracing_guard(False):
                self._build_once(*inputs, **kwargs)

                # TODO(liuyuhui) Only xpu broadcast parameters here.
                # The other device is to call _sync_params_buffers in DataParallel
                # to realize the parameter synchronization among multiply cards.
                needs_xpu_broadcast = (
                    parallel_helper._is_data_parallel_mode()
                    and paddle.is_compiled_with_xpu()
                )
                if needs_xpu_broadcast:
                    parallel_helper._broadcast_parameters(
                        self._parameters.values()
                    )

            self._built = True

        if in_profiler_mode():
            with profiler.RecordEvent(
                self.__class__.__name__, profiler.TracerEventType.Forward
            ):
                outputs = self.forward(*inputs, **kwargs)
        else:
            outputs = self.forward(*inputs, **kwargs)

        for post_hook in self._forward_post_hooks.values():
            replacement = post_hook(self, inputs, outputs)
            if replacement is not None:
                outputs = replacement

        return outputs

1244
    def __call__(self, *inputs, **kwargs):
        """
        Invoke the layer on ``inputs``.

        In plain dygraph mode, with no hooks, no profiler and an unbuilt layer,
        ``_build_once`` and ``forward`` are called directly. Every other case
        goes through ``_dygraph_call_func``.
        """
        # NOTE: the fast path does not set ``self._built``, so ``_build_once``
        # runs on every fast-path call and later calls keep taking this path.
        use_fast_path = (
            not in_declarative_mode()
            and not self._forward_pre_hooks
            and not self._forward_post_hooks
            and not self._built
            and in_dygraph_mode()
            and not in_profiler_mode()
        )
        if not use_fast_path:
            return self._dygraph_call_func(*inputs, **kwargs)

        self._build_once(*inputs, **kwargs)
        return self.forward(*inputs, **kwargs)
M
minqiyang 已提交
1257

1258
    def forward(self, *inputs, **kwargs):
        """
        Defines the computation performed at every call.
        Should be overridden by all subclasses.

        Parameters:
            *inputs(tuple): unpacked tuple arguments
            **kwargs(dict): unpacked dict arguments

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()
X
Xin Pan 已提交
1268 1269 1270 1271

    def backward(self, *inputs):
        """
        Not supported on layers; gradients come from autograd.

        Raises:
            ValueError: always.
        """
        message = "Layer shouldn't implement backward"
        raise ValueError(message)

X
Xin Pan 已提交
1272
    def add_sublayer(self, name, sublayer):
        """

        Adds a sub Layer instance.

        Added sublayer can be accessed by self.name. Registering ``None``
        keeps the name reserved without a layer behind it. An existing entry
        with the same name is overwritten.

        Parameters:
            name(str): name of this sublayer.
            sublayer(Layer|None): an instance of Layer, or None.

        Returns:
            Layer, the sublayer passed in.

        Examples:
            .. code-block:: python

                import paddle

                class MySequential(paddle.nn.Layer):
                    def __init__(self, *layers):
                        super().__init__()
                        if len(layers) > 0 and isinstance(layers[0], tuple):
                            for name, layer in layers:
                                self.add_sublayer(name, layer)
                        else:
                            for idx, layer in enumerate(layers):
                                self.add_sublayer(str(idx), layer)

                    def forward(self, input):
                        for layer in self._sub_layers.values():
                            input = layer(input)
                        return input

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = MySequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        assert isinstance(sublayer, Layer) or sublayer is None

        registry = self._sub_layers
        registry[name] = sublayer
        return sublayer

    def add_parameter(self, name, parameter):
        """Adds a Parameter instance.

        Added parameter can be accessed by self.name. Passing ``None`` reserves
        the name without a parameter behind it. If a state dict is pending in
        ``self._loaddict_holder``, a non-None parameter is initialized from the
        entry matching ``parameter.name``.

        Parameters:
            name(str): name of this parameter.
            parameter(Parameter|None): an instance of Parameter, or None.

        Returns:
            Parameter, the parameter passed in.

        Raises:
            RuntimeError: if ``super().__init__()`` has not been called yet.
            TypeError: if ``name`` is not a str, or ``parameter`` is neither
                None nor a Parameter.
            KeyError: if ``name`` is empty, contains `.`, or is already used by
                a non-parameter attribute.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        if '_parameters' not in self.__dict__:
            raise RuntimeError("super().__init__() should be called firstly.")
        elif not isinstance(name, str):
            raise TypeError(
                "The name of parameter should be a string, but received {}.".format(
                    type(name).__name__
                )
            )
        elif '.' in name:
            raise KeyError(
                "The name of parameter can not contain `.`, "
                "because when you access the newly added parameter in the "
                "form of `self.**.**`, it will cause AttributeError."
            )
        elif name == '':
            raise KeyError("The name of parameter can not be empty.")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("The parameter '{}' already exists.".format(name))
        elif parameter is not None and not isinstance(
            parameter, framework.Parameter
        ):
            raise TypeError(
                "The parameter to be added should be a Parameter, but received {}.".format(
                    type(parameter).__name__
                )
            )

        # A ``None`` placeholder has nothing to restore from a pending state
        # dict. Previously this path dereferenced ``parameter.name`` and raised
        # AttributeError whenever ``_loaddict_holder`` was non-empty.
        if parameter is not None and len(self._loaddict_holder) > 0:
            assert (
                parameter.name in self._loaddict_holder
            ), "Parameter not found, Can't not find [ {} ] in state_dict".format(
                parameter.name
            )

            parameter.set_value(self._loaddict_holder[parameter.name])

        self._parameters[name] = parameter
        return parameter

1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400
    def _set_op_attrs(self, attrs):
        """
        Add customized attribute while append_op. In case of quantization, we want to save
        some attributes into op_desc while exporting inference model by @to_static.

        Arguments:
            attrs(dict): customized attributes that will be added into op_descs.

        NOTE: The interface is only exposed to developers.
        """

        def is_already_registered(is_pre_hook):
1401 1402 1403 1404 1405 1406 1407 1408 1409 1410
            layers_hooks = (
                self._forward_pre_hooks
                if is_pre_hook
                else self._forward_post_hooks
            )
            candidate_hook = (
                record_program_ops_pre_hook
                if is_pre_hook
                else set_op_customized_attrs_post_hook
            )
1411 1412 1413 1414

            already_registed = False
            if layers_hooks:
                last_key = next(reversed(layers_hooks))
1415
                already_registed = layers_hooks[last_key] == candidate_hook
1416 1417 1418 1419

            return already_registed

        if not isinstance(attrs, dict):
1420 1421
            raise TypeError(
                "attrs should be type(dict), but received {}".format(
1422 1423 1424
                    type(attrs).__name__
                )
            )
1425 1426 1427 1428 1429 1430

        # NOTE: Overwrite behavior for same key.
        self._customized_attrs.update(attrs)

        if not is_already_registered(is_pre_hook=True):
            pre_hook_helper = self.register_forward_pre_hook(
1431 1432
                record_program_ops_pre_hook
            )
1433 1434 1435 1436 1437 1438
            assert len(self._op_recorder.hooks) == 0
            self._op_recorder.hooks = [pre_hook_helper]

        # manually register post_hook to ensure it is inserted into the head.
        if not is_already_registered(is_pre_hook=False):
            post_hook_helper = self.register_forward_post_hook(
1439 1440
                set_op_customized_attrs_post_hook
            )
1441
            if len(self._forward_post_hooks) > 1:
1442 1443 1444
                self._forward_post_hooks.move_to_end(
                    post_hook_helper._hook_id, last=False
                )
1445 1446 1447 1448 1449 1450

            assert len(self._op_recorder.hooks) == 1

            # hooks that need to be removed once we finish executing them.
            self._op_recorder.hooks.append(post_hook_helper)

1451 1452 1453 1454 1455 1456
    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)

X
Xin Pan 已提交
1457
    def __getattr__(self, name):
1458 1459 1460
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in self._parameters:
1461
                if in_declarative_mode():
1462
                    return _convert_into_variable(self._parameters[name])
1463 1464 1465 1466 1467 1468 1469 1470
                return self._parameters[name]
        if '_sub_layers' in self.__dict__:
            _sub_layers = self.__dict__['_sub_layers']
            if name in self._sub_layers:
                return self._sub_layers[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
1471
                if in_declarative_mode():
1472
                    return _convert_into_variable(_buffers[name])
1473 1474
                return _buffers[name]
        return object.__getattribute__(self, name)
X
Xin Pan 已提交
1475 1476

    def __setattr__(self, name, value):
        """
        Route attribute assignment into the layer's registries.

        - ``Parameter`` values go into ``_parameters``; the name is removed
          from ``__dict__``, ``_buffers`` and ``_sub_layers``. If a state dict
          is pending in ``_loaddict_holder``, the value is loaded from it.
        - An existing parameter slot only accepts a Parameter or None.
        - ``Layer`` values go into ``_sub_layers``; an existing sublayer slot
          only accepts a Layer or None.
        - Tensor values go into ``_buffers``. New buffer names are marked
          non-persistable, and an unnamed tensor gets a generated name.
        - An existing buffer slot accepts a static ``Variable`` (copied in via
          ``paddle.assign``), None, or raises TypeError otherwise.
        - Everything else becomes a regular instance attribute.
        """

        def _remove_if_exist(*mappings):
            # Drop ``name`` from every mapping that currently contains it.
            for mapping in mappings:
                if name in mapping:
                    del mapping[name]

        # NOTE(review): there is no early return after assigning through a
        # property, so execution continues into the registry logic below.
        if isinstance(getattr(type(self), name, None), property):
            object.__setattr__(self, name, value)

        params = self.__dict__.get('_parameters', None)

        if isinstance(value, framework.Parameter):
            if params is None:
                raise ValueError("super().__init__() should be called first")
            if len(self._loaddict_holder) > 0:
                assert (
                    value.name in self._loaddict_holder
                ), "Parameter not found, Can't not find [ {} ] in state_dict".format(
                    value.name
                )

                value.set_value(self._loaddict_holder[value.name])

            _remove_if_exist(self.__dict__, self._buffers, self._sub_layers)
            params[name] = value
            return

        if params is not None and name in params:
            if value is not None:
                raise TypeError(
                    "assignment to parameter '{}' should be of type Parameter or None, but got '{}'".format(
                        name, type(value).__name__
                    )
                )
            params[name] = None
            return

        layers = self.__dict__.get('_sub_layers', None)

        if isinstance(value, Layer):
            if layers is None:
                raise ValueError("super().__init__() should be called first")
            _remove_if_exist(self.__dict__, self._parameters, self._buffers)
            layers[name] = value
            return

        if layers is not None and name in layers:
            if value is not None:
                raise TypeError(
                    "assignment to sublayer '{}' should be of type Layer or None, but got '{}'".format(
                        name, type(value).__name__
                    )
                )
            layers[name] = None
            return

        _buffers = self.__dict__.get('_buffers', None)

        if isinstance(value, (core.VarBase, core.eager.Tensor)):
            if _buffers is None:
                raise ValueError("super().__init__() should be called first")
            _remove_if_exist(self.__dict__, self._parameters, self._sub_layers)
            # Set persistable=False by default. Only `register_buffer` can
            # add a persistable buffer.
            if name not in self._buffers:
                self._non_persistable_buffer_names_set.add(name)
            if not value.name:
                value.name = unique_name.generate('_buffers_' + name)
            _buffers[name] = value
            return

        if _buffers is None or name not in _buffers:
            object.__setattr__(self, name, value)
            return

        # Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
        # decorated function, such as `self.buffer = new_tensor`. So we update its
        # value via `assign`.
        if type(value) == framework.Variable:
            from paddle import assign

            # Note(zhhsplendid): the condition below happens in PaddleGan model,
            # but should all non-Variable _buffers[name] be re-assign? We
            # should consider it in the future. I current wrote this as
            # conservative code.
            if in_declarative_mode() and _buffers[name] is None:
                raise RuntimeError(
                    'In Dy2stat, self.{0} is a buffer and self.{0} is '
                    'not allowed to be set to Variable when self.{0} is None.'.format(
                        name
                    )
                )
            if (
                _buffers[name] is None
                or type(getattr(self, name)) == core.VarBase
            ):
                _buffers[name] = assign(value)
            else:
                assign(value, getattr(self, name))
            return

        if value is not None:
            raise TypeError(
                "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'".format(
                    name, type(value).__name__
                )
            )

        # Assigning None will remove the buffer, but if re-assign a new varBase to it,
        # it will be remarked as a buffer with same `persistable` attribute.
        _buffers[name] = None
X
Xin Pan 已提交
1579 1580 1581 1582 1583 1584

    def __delattr__(self, name):
        """
        Delete attribute ``name`` from this layer.

        Parameters, sublayers and buffers are stored in dedicated registries,
        so they are removed from the matching registry (a removed buffer is
        also dropped from the non-persistable buffer name set). Any other name
        is handled by the default ``object.__delattr__``.
        """
        params = self._parameters
        if name in params:
            del params[name]
            return

        children = self._sub_layers
        if name in children:
            del children[name]
            return

        buffers = self._buffers
        if name in buffers:
            del buffers[name]
            self._non_persistable_buffer_names_set.discard(name)
            return

        object.__delattr__(self, name)

1591 1592
    def __dir__(self):
        """
        Return a list of the names available on this Layer: class methods and
        attributes, instance attributes, then the names of parameters,
        sublayers and buffers (non-parameter tensors), in that order.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                class Mylayer(paddle.nn.Layer):
                    def __init__(self):
                        super().__init__()
                        self.linear1 = paddle.nn.Linear(10, 10)
                        self.linear2 = paddle.nn.Linear(5, 5)
                        self.conv2d = paddle.nn.Conv2D(3, 2, 3)
                        self.embedding = paddle.nn.Embedding(128, 16)
                        self.h_0 = paddle.to_tensor(np.zeros([10, 10]).astype('float32'))

                mylayer = Mylayer()
                print(dir(mylayer))
                # only parts are shown, because of list have too much content
                # ['__call__', '__class__',  ... , 'conv2d', 'embedding', 'h_0', 'linear1', 'linear2', ... , 'sublayers', 'train']

        """
        name_sources = (
            dir(self.__class__),
            self.__dict__,
            self._parameters,
            self._sub_layers,
            self._buffers,
        )
        # Iterating a dict yields its keys, so every source contributes names.
        return [entry for source in name_sources for entry in source]

1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653
    def extra_repr(self):
        """
        Return extra text to show inside this layer's ``repr``.

        The base implementation returns an empty string. Subclasses may
        override it to describe their configuration; multi-line text is
        indented by ``__repr__``.
        """
        return ""

    def __repr__(self):
        """
        Build a nested, human-readable representation of this layer.

        Format: ``ClassName(`` + ``extra_repr()`` text (inline when it is one
        line, otherwise one indented line each) + one indented
        ``(name): repr(sublayer)`` entry per sublayer + ``)``.
        """
        header_lines = self.extra_repr().split('\n')
        child_lines = [
            '(' + child_name + '): ' + _addindent(repr(child), 2)
            for child_name, child in self._sub_layers.items()
        ]

        pieces = [self.__class__.__name__, '(']
        # str.split always yields at least one element, so this covers both cases.
        if len(header_lines) > 1:
            pieces.append('\n  ' + '\n  '.join(header_lines) + '\n')
        elif header_lines:
            pieces.append(header_lines[0])
        if child_lines:
            pieces.append('\n  ' + '\n  '.join(child_lines) + '\n')
        pieces.append(')')
        return ''.join(pieces)

1654 1655 1656 1657 1658
    def register_state_dict_hook(self, hook):
        """
        Register ``hook`` to post-process the dict produced by ``state_dict()``.

        When state-dict hooks are enabled, each registered hook is called with
        the collected dict, and a non-None return value replaces that dict.

        Parameters:
            hook(callable): called as ``hook(state_dict)``.

        Returns:
            HookRemoveHelper: helper created over this layer's hook registry
            (presumably used to unregister the hook later; see HookRemoveHelper).
        """
        remover = HookRemoveHelper(self._state_dict_hooks)
        hook_id = remover._hook_id
        self._state_dict_hooks[hook_id] = hook
        return remover

1659 1660 1661 1662 1663 1664
    def _obtain_parameters_buffers(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
    ):
        """
        Collect parameters and persistable buffers, keyed by structured name.

        The difference from state_dict() is that state-dict hooks are never
        called, so the original parameter / buffer objects are kept as values.

        Parameters:
            destination(dict, optional): dict that receives this layer's own
                entries; a new ``collections.OrderedDict`` when None. Default: None
            include_sublayers(bool, optional): whether to recurse into
                sublayers. Default: True
            structured_name_prefix(str, optional): prefix prepended to every
                key. Default: ""

        Returns:
            dict: the collected entries. When at least one sublayer is visited,
            this is a copy of ``destination`` rather than the object passed in.
        """
        result = collections.OrderedDict() if destination is None else destination

        for param_name, param in self._parameters.items():
            if param is not None:
                result[structured_name_prefix + param_name] = param

        for buf_name, buf in self._buffers.items():
            if buf is None:
                continue
            if buf_name in self._non_persistable_buffer_names_set:
                continue
            result[structured_name_prefix + buf_name] = buf

        if not include_sublayers:
            return result

        for child_name, child in self._sub_layers.items():
            if child is None:
                continue
            # Each sublayer fills a fresh copy; that copy becomes the
            # accumulator passed on to the next sublayer.
            merged = result.copy()
            merged.update(
                child._obtain_parameters_buffers(
                    merged,
                    include_sublayers,
                    structured_name_prefix + child_name + ".",
                )
            )
            result = merged
        return result

1695 1696 1697 1698 1699 1700 1701 1702
    def _state_dict_impl(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        include_non_persistable_buffer=False,
        use_hook=True,
    ):
        """
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : Dict that receives this layer's own entries; a new ``collections.OrderedDict`` when None. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key; sublayer keys extend it with ``"<sublayer_name>."``. Default: ""
            include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False
            use_hook(bool, optional) : If true, every hook in ``_state_dict_hooks`` is called with the collected dict, and a non-None return value replaces it. Default: True

        Returns:
            dict: the collected entries (possibly a copy of, or replaced by a hook instead of, ``destination``).
        """
        result = collections.OrderedDict() if destination is None else destination

        for param_name, param in self._parameters.items():
            if param is not None:
                result[structured_name_prefix + param_name] = param

        for buf_name, buf in self._buffers.items():
            keep = buf is not None and (
                include_non_persistable_buffer
                or buf_name not in self._non_persistable_buffer_names_set
            )
            if keep:
                result[structured_name_prefix + buf_name] = buf

        if include_sublayers:
            for child_name, child in self._sub_layers.items():
                if child is None:
                    continue
                # Merge through a copy so a sublayer hook that returns a new
                # dict can add/override entries without dropping earlier ones.
                merged = result.copy()
                merged.update(
                    child._state_dict_impl(
                        merged,
                        include_sublayers,
                        structured_name_prefix + child_name + ".",
                        include_non_persistable_buffer,
                        use_hook,
                    )
                )
                result = merged

        if use_hook:
            for hook in self._state_dict_hooks.values():
                replaced = hook(result)
                if replaced is not None:
                    result = replaced

        return result

1751 1752 1753 1754 1755 1756 1757
    def to_static_state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        '''

        Get all parameters and all buffers (persistable and non-persistable) of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provide, this layer's own parameters and buffers are written into this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key. Default: ""
            use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

        Returns:
            dict, a dict contains all the parameters and buffers, including non-persistable buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.to_static_state_dict()
                paddle.save( state_dict, "paddle_dy.pdparams")

        '''
        # Unlike state_dict(), non-persistable buffers are included here.
        return self._state_dict_impl(
            include_non_persistable_buffer=True,
            use_hook=use_hook,
            structured_name_prefix=structured_name_prefix,
            include_sublayers=include_sublayers,
            destination=destination,
        )

    def state_dict(
        self,
        destination=None,
        include_sublayers=True,
        structured_name_prefix="",
        use_hook=True,
    ):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

        Parameters:
            destination(dict, optional) : If provide, this layer's own parameters and persistable buffers are written into this dict. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key. Default: ""
            use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

        Returns:
            dict: a dict contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save( state_dict, "paddle_dy.pdparams")

        '''
        # Non-persistable buffers are deliberately excluded from the saved state.
        return self._state_dict_impl(
            include_non_persistable_buffer=False,
            use_hook=use_hook,
            structured_name_prefix=structured_name_prefix,
            include_sublayers=include_sublayers,
            destination=destination,
        )
1825

1826
    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

        Entries whose key is absent from ``state_dict`` or whose shape/length does
        not match are skipped with a warning and reported in ``missing_keys``.

        Parameters:
            state_dict(dict) : Dict contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            missing_keys(list):A list of str containing the missing keys
            unexpected_keys(list):A list of str containing the unexpected keys

        Raises:
            ValueError: in static mode, if restoring the matched values raises
                ValueError (the original error is chained as the cause).

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")
                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''
        missing_keys = []
        match_keys = set()
        unexpected_keys = []

        def _check_match(key, param):
            # Return (param, state) when state_dict holds a compatible value for
            # `key`; otherwise record `key` as missing and raise ValueError.
            state = state_dict.get(key, None)
            if state is None:
                missing_keys.append(key)
                raise ValueError(
                    "{} is not found in the provided dict.".format(key)
                )
            if isinstance(state, (dict, list)):
                # Container-valued entries are only compared by length.
                if len(state) != len(param):
                    missing_keys.append(key)
                    raise ValueError(
                        "{} receieves the length of {}, "
                        "but the expected shape is {}".format(
                            key, len(state), len(param)
                        )
                    )
                else:
                    match_keys.add(key)
                    return param, state
            else:
                # `shape` is a method on some tensor types and a plain
                # attribute on others.
                state_shape = (
                    state.shape()
                    if inspect.ismethod(state.shape)
                    else state.shape
                )

                if list(state_shape) != list(param.shape):
                    missing_keys.append(key)
                    raise ValueError(
                        "{} receives a shape {}, but the expected shape is {}.".format(
                            key, list(state_shape), list(param.shape)
                        )
                    )
                match_keys.add(key)
                return param, state

        matched_param_state = []
        for key, param in self._state_dict_impl(use_hook=False).items():
            key_name = key if use_structured_name else param.name
            try:
                match_res = _check_match(key_name, param)
            except ValueError as err:
                warnings.warn("Skip loading for {}. ".format(key) + str(err))
            else:
                matched_param_state.append(match_res)
        for key in state_dict.keys():
            if key not in match_keys:
                unexpected_keys.append(key)
        if in_dygraph_mode():
            for param, state in matched_param_state:
                param.set_value(state)
        else:

            def _set_var(var, ndarray):
                # Write `ndarray` into the scope tensor backing `var`, using a
                # place of the same kind as the one that tensor is on.
                t = global_scope().find_var(var.name).get_tensor()
                p = t._place()
                if p.is_cpu_place():
                    place = core.CPUPlace()
                elif p.is_cuda_pinned_place():
                    place = core.CUDAPinnedPlace()
                elif p.is_xpu_place():
                    p = core.Place()
                    p.set_place(t._place())
                    place = core.XPUPlace(p.xpu_device_id())
                else:
                    p = core.Place()
                    p.set_place(t._place())
                    place = core.CUDAPlace(p.gpu_device_id())
                t.set(ndarray, place)

            try:
                executor = Executor(_get_device())._default_executor
                # restore parameter states
                core._create_loaded_parameter(
                    [param for param, state in matched_param_state],
                    global_scope(),
                    executor,
                )
                for param, state in matched_param_state:
                    _set_var(param, state)
            except ValueError as e:
                # Chain the original error so its details stay in the traceback.
                raise ValueError(
                    "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
                ) from e

        return missing_keys, unexpected_keys

C
chentianyu03 已提交
1942 1943 1944 1945 1946
    def to(self, device=None, dtype=None, blocking=None):
        '''
        Cast the parameters and buffers of Layer by the given device, dtype and blocking.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device of the Layer which want to be stored.
            If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
            index of the GPUs or XPUs. Default: None.

            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.

        Returns:
            self

        Examples:
            .. code-block:: python

                # required: skip
                import paddle

                linear=paddle.nn.Linear(2, 2)
                linear.weight
                #Parameter containing:
                #Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(dtype='float64')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(device='cpu')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])
                linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
                #       [[-0.04989364, -0.56889004],
                #        [ 0.33960250,  0.96878713]])

        '''
        # Public entry point: convert every tensor, recursing into sublayers.
        return self._to_impl(
            floating_only=False,
            include_sublayers=True,
            blocking=blocking,
            dtype=dtype,
            device=device,
        )
1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009

    def _apply(self, func, device, dtype, blocking, include_sublayers=True):
        """
        Apply ``func`` to the parameters (and their gradients) and buffers of
        this layer, after optionally recursing into ``self.children()``.

        Parameters:
            func(callable): called as ``func(tensor, device, dtype, blocking)``.
                Its result is discarded for parameters and gradients, so it is
                expected to update those tensors in place; for buffers the
                result is stored back into ``self._buffers``.
            device: target device, forwarded to ``func``.
            dtype: target dtype, forwarded to ``func``. When not None it is also
                recorded as this layer's ``_dtype``.
            blocking(bool): forwarded to ``func``.
            include_sublayers(bool, optional): whether to process sublayers
                first. Default: True
        """
        if include_sublayers:
            for layer in self.children():
                layer._apply(func, device, dtype, blocking, include_sublayers)

        for param in self._parameters.values():
            if param is None:
                continue
            # Return value intentionally ignored: the transform used by
            # `_to_impl` shares the converted data back into `param` itself.
            with no_grad():
                func(param, device, dtype, blocking)
            if param.grad is not None:
                with no_grad():
                    func(param._grad_ivar(), device, dtype, blocking)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = func(buf, device, dtype, blocking)

        # Only record a dtype that was actually requested; a device-only move
        # (dtype=None) must not erase the layer's existing dtype.
        if dtype is not None:
            self._dtype = dtype

2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035
    def _transform(self, t, device, dtype, blocking):
        """
        Convert tensor ``t`` to ``dtype`` and move it to ``device``, then make
        ``t`` share the converted data.

        Parameters:
            t(Tensor): parameter / buffer / gradient to convert.
            device(Place|None): target place; None keeps ``t.place``.
            dtype(str|numpy.dtype|VarType|None): target dtype; None keeps ``t.dtype``.
            blocking(bool): forwarded to ``_copy_to``.

        Returns:
            Tensor: ``t`` itself, now backed by the converted data.
        """
        target_place = t.place if device is None else device
        target_dtype = t.dtype if dtype is None else dtype
        if type(target_dtype) is not VarDesc.VarType:
            target_dtype = convert_np_dtype_to_dtype_(target_dtype)

        # Step 1: on GPU, stage the data through CPU when free device memory
        # may not be enough for the converted copy.
        staged = t
        if t.place.is_gpu_place():
            elem_bytes = core.size_of_dtype(target_dtype)
            # Note(zhangbo): Paddle's minimum GPU allocation unit is 256 bytes;
            # the 1.2 factor leaves headroom against OOM near the limit.
            # NOTE(review): `/` is true division, so this evaluates to
            # (bytes + 256) * 1.2 rather than rounding up to a 256-byte multiple.
            required_bytes = (
                ((np.prod(t.shape) * elem_bytes) / 256 + 1) * 256 * 1.2
            )
            if core.gpu_memory_available() < required_bytes:
                staged = t._copy_to(paddle.CPUPlace(), blocking)  # k-v type will error
                # Release the original device memory now that a CPU copy exists.
                t.value().get_tensor()._clear()

        # Step 2: cast on whichever place the data currently lives.
        if target_dtype is not None and target_dtype != staged.dtype:
            with paddle.fluid.framework._dygraph_place_guard(place=staged.place):
                casted = staged.cast(dtype=target_dtype)
        else:
            casted = staged

        # Step 3: move the casted data to the requested place if needed.
        if target_place is not None and not casted.place._equals(target_place):
            moved = casted._copy_to(target_place, blocking)
        else:
            moved = casted

        # Step 4: point the original tensor at the converted data.
        t.value().get_tensor()._share_data_with(moved.value().get_tensor())
        return t

2073 2074 2075 2076 2077 2078 2079 2080
    def _to_impl(
        self,
        device=None,
        dtype=None,
        blocking=None,
        include_sublayers=True,
        floating_only=False,
    ):
        '''
        Cast the parameters and buffers of Layer by the given device, dtype and blocking.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device of the Layer which want to be stored.
            If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
            index of the GPUs or XPUs. Default: None.

            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same with the original Tensor
            and the layer's recorded dtype is left unchanged. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.

            include_sublayers(bool, optional): If True, deal with self and all sublayers parameters and buffers, if not only deal with self parameters and buffers. Default: True.

            floating_only(bool, optional): If True, only floating point parameters and buffers are cast; other tensors are left unchanged. Default: False.

        Returns:
            self

        Raises:
            ValueError: if ``device`` is not a str or one of the supported place types.
            AssertionError: if ``blocking`` is neither None nor a bool.

        '''

        if device is None and dtype is None and blocking is None:
            return self

        if device is not None:
            if isinstance(device, str):
                device = paddle.device._convert_to_place(device)
            elif isinstance(
                device,
                (
                    core.CPUPlace,
                    core.CUDAPlace,
                    core.CUDAPinnedPlace,
                    core.XPUPlace,
                ),
            ):
                pass
            else:
                raise ValueError(
                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
                    + type(device).__name__
                )

        if blocking is None:
            blocking = True
        else:
            assert isinstance(
                blocking, bool
            ), "blocking value error, must be the True, False or None"

        def transform(t, device, dtype, blocking):
            # Leave non-floating tensors untouched when only floats are cast.
            if floating_only and (not paddle.is_floating_point(t)):
                return t
            return self._transform(t, device, dtype, blocking)

        # Suppress UserWarnings emitted during the per-tensor conversion.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            self._apply(transform, device, dtype, blocking, include_sublayers)

        # A device-only or blocking-only call must not erase the layer's dtype.
        if dtype is not None:
            self._dtype = dtype
        return self
C
chentianyu03 已提交
2143

2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155
    def _startup_program(self):
        """
        Build a new startup ``Program`` containing the initialization op of
        every parameter yielded by ``self.parameters()``.

        NOTE(dev): This is a very low level API and only for inner developer.
        """
        program = Program()
        for parameter in self.parameters():
            # Each parameter appends its own init op to the global block.
            parameter._create_init_op(program.global_block())
        return program

2156 2157 2158
    # [aliases] Keep the legacy method names working: both are the same
    # function object as set_state_dict.
    set_dict = load_dict = set_state_dict