layers.py 54.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
Xin Pan 已提交
15
import collections
16 17 18
import contextlib
import sys
import numpy as np
19
import six
20
import re
21 22 23
import copy
import weakref
import warnings
24
from copy import deepcopy
25 26
import inspect

27
import paddle
28

C
chengduo 已提交
29
from . import parallel_helper
X
Xin Pan 已提交
30
from .. import unique_name
31
from paddle.fluid import core
32
from .layer_object_helper import LayerObjectHelper
33
from .base import program_desc_tracing_guard, param_guard
34
from paddle.fluid import framework
35
from ..param_attr import ParamAttr
36 37 38
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import _current_expected_place as _get_device
C
chentianyu03 已提交
39
from paddle.fluid.dygraph import no_grad
W
wanghuancoder 已提交
40
import paddle.utils.deprecated as deprecated
41

42
__all__ = ['Layer']
43

44 45 46 47 48 49 50 51
_first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re = re.compile('([a-z])([A-Z])')


def _convert_camel_to_snake(name):
    s1 = _first_cap_re.sub(r'\1_\2', name)
    return _all_cap_re.sub(r'\1_\2', s1).lower()

52

53 54 55 56 57 58 59 60 61 62 63
def _addindent(string, indent):
    s1 = string.split('\n')
    if len(s1) == 1:
        return string
    s2 = []
    for idx, line in enumerate(s1):
        if idx > 0:
            s2.append(str((indent * ' ') + line))
    return s1[0] + '\n' + '\n'.join(s2)


64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
class HookRemoveHelper(object):
    """Handle returned by hook registration; call ``remove()`` to unregister the hook."""

    # Class-wide counter so every handle gets a distinct id.
    next_hook_id = 0

    def __init__(self, hooks):
        # Only a weak reference is kept, so the handle never keeps the
        # owning hooks dict alive on its own.
        self._hooks_ref = weakref.ref(hooks)
        self._hook_id = HookRemoveHelper.next_hook_id
        HookRemoveHelper.next_hook_id = self._hook_id + 1

    def remove(self):
        """Remove the hook from its dict. A no-op if already removed or the dict is gone."""
        hooks = self._hooks_ref()
        if hooks is None:
            return
        hooks.pop(self._hook_id, None)


class Layer(core.Layer):
81 82
    """
    Dynamic graph Layer based on OOD, includes the parameters of the layer, the structure of the forward graph and so on.
X
Xin Pan 已提交
83

84
    Parameters:
85 86
        name_scope (str, optional): prefix name used by the layer to name parameters.
            If prefix is "my_layer", parameter name in MyLayer
87 88 89
            can be "my_layer_0.w_n", where "w" is the parameter
            base name and "n" is an unique suffix auto-generated.
            If None, prefix name will be snake cased class name. Default: None.
90
        dtype(str, optional): data type of this parameter.
91 92
                If set str, it can be "bool",  "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16".
93
                Default: "float32"
94 95 96
    
    Returns:
        None
X
Xin Pan 已提交
97
    """
X
Xin Pan 已提交
98

99
    def __init__(self, name_scope=None, dtype="float32"):
        """Initialise naming, mode flag, and the parameter/buffer/sublayer/hook registries."""
        # New layers start in training mode; `train()` / `eval()` toggle it.
        self.training = True

        # Prefix for generated parameter names; defaults to the snake-cased
        # class name (e.g. "MyLayer" -> "my_layer").
        scope = name_scope
        if scope is None:
            scope = _convert_camel_to_snake(self.__class__.__name__)
        self._full_name = unique_name.generate(scope)
        self._helper = LayerObjectHelper(self._full_name)
        self._built = False
        self._dtype = dtype
        # Whether the layer was constructed while dygraph mode was active.
        self._init_in_dynamic_mode = framework.in_dygraph_mode()

        # Trainable parameters, in insertion order.
        self._parameters = collections.OrderedDict()
        # Buffers: variables (not parameters) created in this layer.
        self._buffers = collections.OrderedDict()
        self._non_persistable_buffer_names_set = set()
        # Child layers, in insertion order.
        self._sub_layers = collections.OrderedDict()
        self._loaddict_holder = collections.OrderedDict()

        # Hook tables keyed by HookRemoveHelper ids.
        self._forward_pre_hooks = collections.OrderedDict()
        self._forward_post_hooks = collections.OrderedDict()

    def train(self):
        """
        Put this Layer and all of its sublayers into training mode.

        Only layers whose forward pass depends on the mode, such as
        `Dropout` and `BatchNorm`, are affected by this switch.

        Returns:
            None

        Example::
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                mylayer.train()  # set mylayer._dropout to train mode
                out = mylayer(x)

        """
        # Global mode lives on the dygraph tracer. nn.Layer can also be used
        # in static mode, where _dygraph_tracer() must not be called.
        if in_dygraph_mode():
            framework._dygraph_tracer().train_mode()

        # Per-layer flag, for this layer and every descendant.
        self.training = True
        for sub in self.sublayers():
            sub.training = True

    def eval(self):
        """
        Put this Layer and all of its sublayers into evaluation mode.

        Only layers whose forward pass depends on the mode, such as
        `Dropout` and `BatchNorm`, are affected by this switch.

        Returns:
            None

        Example::
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                x = paddle.randn([10, 1], 'float32')
                mylayer = MyLayer()
                mylayer.eval()  # set mylayer._dropout to eval mode
                out = mylayer(x)
                print(out)

        """
        # Global mode lives on the dygraph tracer. nn.Layer can also be used
        # in static mode, where _dygraph_tracer() must not be called.
        if in_dygraph_mode():
            framework._dygraph_tracer().eval_mode()

        # Per-layer flag, for this layer and every descendant.
        self.training = False
        for sub in self.sublayers():
            sub.training = False

    def apply(self, fn):
        """
        Call ``fn`` on every descendant layer and then on this layer.

        The traversal is post-order: each child's subtree is visited before
        ``fn(self)`` runs. A typical use is custom parameter initialisation.

        Parameters:
            fn (function): callable taking a single Layer argument.

        Returns:
            Layer: self, so calls can be chained.

        Example::
            .. code-block:: python

              import paddle
              import paddle.nn as nn

              net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

              def init_weights(layer):
                  if isinstance(layer, nn.Linear):
                      print('before init weight:', layer.weight.numpy())
                      new_weight = paddle.full(shape=layer.weight.shape, dtype=layer.weight.dtype, fill_value=0.9)
                      layer.weight.set_value(new_weight)
                      print('after init weight:', layer.weight.numpy())

              net.apply(init_weights)

              print(net.state_dict())
        """
        # Recurse into immediate children first; each handles its own subtree.
        for child in self.children():
            child.apply(fn)
        fn(self)
        return self

    def full_name(self):
        """Return the unique full name of this layer.

        The name is generated once, at construction, from ``name_scope``
        (or the snake-cased class name when ``name_scope`` is None) via
        ``unique_name.generate``. That appends a numeric suffix, e.g.
        ``"demo_linear_net_0"``.

        Returns:
            str: full name of this layer.

        Example::
            .. code-block:: python

                import paddle

                class LinearNet(paddle.nn.Layer):
                    def __init__(self):
                        super(LinearNet, self).__init__(name_scope = "demo_linear_net")
                        self._linear = paddle.nn.Linear(1, 1)

                    def forward(self, x):
                        return self._linear(x)

                linear_net = LinearNet()
                print(linear_net.full_name())   # demo_linear_net_0

        """
        return self._full_name

    def register_forward_post_hook(self, hook):
        """Register a forward post-hook that runs after `forward` has been computed.

        The hook receives the Layer, its input, and its output, in that order.
        It can return a modified output, or None to leave the output unchanged.
        Typical uses are rewriting a Layer's output or collecting statistics
        about it.

        hook(Layer, input, output) -> None or modified output

        Parameters:
            hook(function): a function registered as a forward post-hook

        Returns:
            HookRemoveHelper: a handle whose `remove()` method unregisters the hook.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_post_hook changes the output of the layer: output = output * 2
                def forward_post_hook(layer, input, output):
                    # user can use layer, input and output for information statistics tasks

                    # change the output
                    return output * 2

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_post_hook_handle = linear.register_forward_post_hook(forward_post_hook)

                value1 = np.arange(26).reshape(2, 13).astype("float32")
                in1 = paddle.to_tensor(value1)

                out0 = linear(in1)

                # remove the hook
                forward_post_hook_handle.remove()

                out1 = linear(in1)

                # the hook changed the linear's output to output * 2, so out0 equals out1 * 2
                assert (out0.numpy() == (out1.numpy()) * 2).all()
        """
        # The handle's id is the key under which the hook is stored, so
        # `remove()` can later delete exactly this entry.
        handle = HookRemoveHelper(self._forward_post_hooks)
        self._forward_post_hooks[handle._hook_id] = hook
        return handle

    def register_forward_pre_hook(self, hook):
        """Register a forward pre-hook that runs before `forward` is computed.

        The hook receives the Layer and its input. It can return either a
        tuple or a single modified value. A single value that is not already
        a tuple is wrapped into one. Typical uses are rewriting a Layer's
        input or collecting statistics about it.

        hook(Layer, input) -> None or modified input

        Parameters:
            hook(function): a function registered as a forward pre-hook

        Returns:
            HookRemoveHelper: a handle whose `remove()` method unregisters the hook.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                # the forward_pre_hook changes the input of the layer: input = input * 2
                def forward_pre_hook(layer, input):
                    # user can use layer and input for information statistics tasks

                    # change the input
                    input_return = (input[0] * 2)
                    return input_return

                linear = paddle.nn.Linear(13, 5)

                # register the hook
                forward_pre_hook_handle = linear.register_forward_pre_hook(forward_pre_hook)

                value0 = np.arange(26).reshape(2, 13).astype("float32")
                in0 = paddle.to_tensor(value0)
                out0 = linear(in0)

                # remove the hook
                forward_pre_hook_handle.remove()

                value1 = value0 * 2
                in1 = paddle.to_tensor(value1)
                out1 = linear(in1)

                # the hook changed the linear's input to input * 2, so out0 equals out1
                assert (out0.numpy() == out1.numpy()).all()
        """
        # The handle's id is the key under which the hook is stored, so
        # `remove()` can later delete exactly this entry.
        handle = HookRemoveHelper(self._forward_pre_hooks)
        self._forward_pre_hooks[handle._hook_id] = hook
        return handle

    def create_parameter(self,
                         shape,
                         attr=None,
                         dtype=None,
                         is_bias=False,
                         default_initializer=None):
        """Create a parameter for this layer.

        Parameters:
            shape(list): Shape of the parameter.
            attr(ParamAttr, optional): Parameter attribute of weight. Please refer to :ref:`api_paddle_ParamAttr`.
                An empty string is treated like None. Default: None.
            dtype(str, optional): Data type of this parameter, e.g. "bool", "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16". If None, the choice is left to the
                layer's helper. Default: None.
            is_bias(bool, optional): Whether this is a bias parameter. Default: False.
            default_initializer(Initializer, optional): Initializer used when none is given by ``attr``.
                If None, paddle.nn.initializer.Xavier is used for non-bias parameters and
                paddle.nn.initializer.Constant for bias parameters. Default: None.

        Returns:
            Tensor: the created parameter.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        # Work on a copy so the caller's attr object is never mutated downstream.
        attr_copy = copy.deepcopy(attr)
        if isinstance(attr_copy, six.string_types) and attr_copy == "":
            attr_copy = None
        return self._helper.create_parameter(
            attr_copy, shape, dtype, is_bias, default_initializer)

    @deprecated(
        since="2.0.0",
        update_to="paddle.nn.Layer.create_tensor",
        reason="New api in create_tensor, easier to use.")
    def create_variable(self, name=None, persistable=None, dtype=None):
        """
        Create a Tensor for this layer (deprecated alias of ``create_tensor``).

        Parameters:
            name(str, optional): name of the tensor. It is prefixed with the layer's full name;
                if None, a unique name is generated. Please refer to :ref:`api_guide_Name` . Default: None
            persistable(bool, optional): whether this tensor is persistable. Default: None
            dtype(str, optional): data type of this tensor, e.g. "bool", "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super(MyLinear, self).__init__()
                        self.linear = paddle.nn.Linear( 10, 10)

                        self.back_var = self.create_variable(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        # This used to be a verbatim copy of create_tensor's body; delegate so
        # the deprecated alias can never drift from the supported API.
        return self.create_tensor(
            name=name, persistable=persistable, dtype=dtype)

    # TODO: Add more parameter list when we need them
    def create_tensor(self, name=None, persistable=None, dtype=None):
        """
        Create a Tensor (LOD_TENSOR variable) in the current block of this layer's main program.

        Parameters:
            name(str, optional): name of the tensor. It is prefixed with the layer's full name
                ("<full_name>.<name>"); if None, a unique "<full_name>._generated_var" name is
                generated. Please refer to :ref:`api_guide_Name` . Default: None
            persistable(bool, optional): whether this tensor is persistable. Default: None
            dtype(str, optional): data type of this tensor, e.g. "bool", "float16", "float32", "float64",
                "int8", "int16", "int32", "int64", "uint8" or "uint16". Default: None

        Returns:
            Tensor, created Tensor.

        Examples:
            .. code-block:: python

                import paddle

                class MyLinear(paddle.nn.Layer):
                    def __init__(self,
                                in_features,
                                out_features):
                        super(MyLinear, self).__init__()
                        self.linear = paddle.nn.Linear( 10, 10)

                        self.back_var = self.create_tensor(name = "linear_tmp_0", dtype=self._dtype)

                    def forward(self, input):
                        out = self.linear(input)
                        paddle.assign( out, self.back_var)

                        return out

        """
        if name is None:
            var_name = unique_name.generate(".".join(
                [self._full_name, "_generated_var"]))
        else:
            var_name = ".".join([self._full_name, name])

        block = self._helper.main_program.current_block()
        return block.create_var(
            name=var_name,
            persistable=persistable,
            dtype=dtype,
            type=core.VarDesc.VarType.LOD_TENSOR)

    def parameters(self, include_sublayers=True):
        """Returns a list of all Parameters from current layer and its sub-layers.

        Parameters:
            include_sublayers(bool, optional): Whether to include the parameters of sublayers.
                If False, only this layer's own parameters are returned. Default: True.

        Returns:
            list of Tensor : a list of Parameters, each appearing once.

        Examples:
            .. code-block:: python

                import paddle

                linear = paddle.nn.Linear(1,1)
                print(linear.parameters())  # print linear_0.w_0 and linear_0.b_0

        """
        # named_parameters already de-duplicates shared parameters.
        return [
            param for _, param in self.named_parameters(
                include_sublayers=include_sublayers)
        ]

    def children(self):
        """Iterate over the immediate child layers (without their names).

        Yields:
            Layer: a child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)

                layer_list = list(model.children())

                print(layer_list)   # [<paddle.nn.layer.common.Linear object at 0x7f7b8113f830>, <paddle.nn.layer.common.Linear object at 0x7f7b8113f950>]

        """
        for _name, child in self.named_children():
            yield child

    def named_children(self):
        """Iterate over the immediate child layers as ``(name, layer)`` pairs.

        None entries are skipped, and a child registered under several
        names is yielded only once (under its first name).

        Yields:
            (string, Layer): Tuple containing a name and child layer

        Examples:
            .. code-block:: python

                import paddle

                linear1 = paddle.nn.Linear(10, 3)
                linear2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(linear1, linear2)
                for prefix, layer in model.named_children():
                    print(prefix, layer)
                    # ('0', <paddle.nn.layer.common.Linear object at 0x7fb61ed85830>)
                    # ('1', <paddle.nn.layer.common.Linear object at 0x7fb61ed85950>)

        """
        seen = set()
        for child_name, child in self._sub_layers.items():
            if child is None or child in seen:
                continue
            seen.add(child)
            yield child_name, child

    def sublayers(self, include_self=False):
        """Return every descendant layer as a list (each layer appears once).

        Parameters:
            include_self(bool, optional): Whether to put this layer itself first in the list. Default: False

        Returns:
            list of Layer : a list of sub layers.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        self._dropout = paddle.nn.Dropout(p=0.5)

                    def forward(self, input):
                        temp = self._linear(input)
                        temp = self._dropout(temp)
                        return temp

                mylayer = MyLayer()
                print(mylayer.sublayers())  # [<paddle.nn.layer.common.Linear object at 0x7f44b58977d0>, <paddle.nn.layer.common.Dropout object at 0x7f44b58978f0>]

        """
        named = self.named_sublayers(include_self=include_self)
        return list(sub for _prefix, sub in named)

    def named_parameters(self, prefix='', include_sublayers=True):
        """
        Iterate over the parameters of this Layer as ``(name, parameter)`` pairs.

        Names are dotted paths built from the sublayer names. None entries
        are skipped, and a parameter shared between layers is yielded only
        once.

        Parameters:
            prefix(str, optional): Prefix to prepend to all parameter names. Default: ''.
            include_sublayers(bool, optional): Whether to also walk into sublayers.
                If False, only this layer's own parameters are yielded. Default: True.

        Yields:
            (string, Parameter): Tuple of name and Parameter

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for name, param in model.named_parameters():
                    print(name, param)

        """
        if include_sublayers:
            layer_iter = self.named_sublayers(prefix=prefix, include_self=True)
        else:
            layer_iter = zip([prefix], [self])

        seen = set()
        for layer_prefix, owner in layer_iter:
            for key, param in owner._parameters.items():
                if param is None or param in seen:
                    continue
                seen.add(param)
                yield layer_prefix + ('.' if layer_prefix else '') + key, param

    def named_sublayers(self, prefix='', include_self=False, layers_set=None):
        """
        Iterate depth-first over the sublayers of this Layer as ``(name, layer)`` pairs.

        Names are dotted paths built from the child names. A layer reachable
        through several paths is yielded only once (the first time it is met).

        Parameters:
            prefix(str, optional): Prefix to prepend to all sublayer names. Default: ''.
            include_self(bool, optional): Whether to yield this Layer itself first. Default: False.
            layers_set(set, optional): Layers already visited; they are not yielded again.
                Used to share state across the recursion. Default: None (a fresh set).

        Yields:
            (string, Layer): Tuple of name and Layer

        Examples:
            .. code-block:: python

                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = paddle.nn.Sequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        visited = set() if layers_set is None else layers_set
        if include_self and self not in visited:
            visited.add(self)
            yield prefix, self

        for key, child in self._sub_layers.items():
            if child is None:
                continue
            child_prefix = prefix + ('.' if prefix else '') + key
            descendants = child.named_sublayers(
                prefix=child_prefix, include_self=True, layers_set=visited)
            for sub_prefix, sub in descendants:
                yield sub_prefix, sub

    def register_buffer(self, name, tensor, persistable=True):
        """
        Register a tensor as a buffer of this layer.

        A `buffer` is a non-trainable tensor that the optimizer never updates
        but that evaluation and inference need, e.g. the running mean and
        variance of BatchNorm. Buffers are persistable by default and are then
        saved into `state_dict` next to the parameters. With persistable=False
        the buffer is kept out of `state_dict` .

        A registered buffer can be read back as an attribute under its name.

        Parameters:
            name (string): name of the buffer; the buffer is accessible as
                an attribute of this layer under this name.
            tensor (Tensor): the tensor to register as buffer (None is accepted).
            persistable (bool): whether the buffer is part of this layer's
                state_dict. Default: True.

        Returns:
            None

        Raises:
            ValueError: if the Layer's __init__ has not been called yet.
            TypeError: if `name` is not a string, or `tensor` is neither None nor a core.VarBase.
            KeyError: if `name` contains `.`, is empty, or clashes with an existing non-buffer attribute.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                # get the buffer by attribute.
                print(linear.buf_name)

        """
        # Validation order matters: the first failing check decides which
        # error is raised.
        if '_buffers' not in self.__dict__:
            raise ValueError(
                "super(YourLayer, self).__init__() should be called first")
        if not isinstance(name, six.string_types):
            raise TypeError(
                "The name of buffer should be a string, but received {}.".
                format(type(name).__name__))
        if '.' in name:
            raise KeyError(
                "The name of buffer can not contain `.`, "
                "because when you access the newly added buffer in the "
                "form of `self.**.**`, it will cause AttributeError.")
        if name == '':
            raise KeyError("The name of buffer can not be empty.")
        if hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists.".format(name))
        if tensor is not None and type(tensor) != core.VarBase:
            raise TypeError(
                "The registered buffer should be a core.VarBase, but received {}.".
                format(type(tensor).__name__))

        self._buffers[name] = tensor
        non_persistable = self._non_persistable_buffer_names_set
        if persistable:
            non_persistable.discard(name)
        else:
            non_persistable.add(name)

    def buffers(self, include_sublayers=True):
        """
        Return the buffers of this layer (and, optionally, of its sublayers) as a list.

        Parameters:
            include_sublayers(bool, optional): Whether to also collect the buffers of sublayers. Default: True

        Returns:
            list of Tensor : a list of buffers.

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                linear = paddle.nn.Linear(10, 3)
                value = np.array([0]).astype("float32")
                buffer = paddle.to_tensor(value)
                linear.register_buffer("buf_name", buffer, persistable=True)

                print(linear.buffers())     # == print([linear.buf_name])

        """
        pairs = self.named_buffers(include_sublayers=include_sublayers)
        return [buf for _name, buf in pairs]

    def named_buffers(self, prefix='', include_sublayers=True):
        """
        Iterate over the buffers of this layer, yielding ``(name, tensor)`` pairs.

        ``None`` buffers are skipped. A tensor shared by several layers is
        yielded only once, the first time it is met. Names are the owning
        layer's prefix joined to the buffer key with ``.``.

        Parameters:
            prefix(str, optional): Prefix to prepend to all buffer names. Default: ''.
            include_sublayers(bool, optional): Whether include the buffers of sublayers.
                If True, also include the named buffers from sublayers. Default: True.

        Yields:
            (string, Tensor): Tuple of name and tensor

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle

                fc1 = paddle.nn.Linear(10, 3)
                buffer1 = paddle.to_tensor(np.array([0]).astype("float32"))
                # register a tensor as buffer by specific `persistable`
                fc1.register_buffer("buf_name_1", buffer1, persistable=True)

                fc2 = paddle.nn.Linear(3, 10)
                buffer2 = paddle.to_tensor(np.array([1]).astype("float32"))
                # register a buffer by assigning an attribute with Tensor.
                # The `persistable` can only be False by this way.
                fc2.buf_name_2 = buffer2

                model = paddle.nn.Sequential(fc1, fc2)

                # get all named buffers
                for name, buffer in model.named_buffers():
                    print(name, buffer)

        """
        if include_sublayers:
            owners = self.named_sublayers(prefix=prefix, include_self=True)
        else:
            owners = zip([prefix], [self])

        already_yielded = set()
        for owner_prefix, owner in owners:
            for buf_key, tensor in owner._buffers.items():
                if tensor is not None and tensor not in already_yielded:
                    already_yielded.add(tensor)
                    separator = '.' if owner_prefix else ''
                    yield owner_prefix + separator + buf_key, tensor

    def clear_gradients(self):
        """
        Reset the gradient of every trainable parameter of this layer.

        Parameters whose ``trainable`` flag is False are left alone.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                value = np.arange(26).reshape(2, 13).astype("float32")
                a = paddle.to_tensor(value)
                linear = paddle.nn.Linear(13, 5)
                adam = paddle.optimizer.Adam(learning_rate=0.01,
                                            parameters=linear.parameters())
                out = linear(a)
                out.backward()
                adam.step()
                linear.clear_gradients()

        """
        trainable_params = (param for param in self.parameters()
                            if param.trainable)
        for param in trainable_params:
            param.clear_gradient()

    def _build_once(self, *args, **kwargs):
        """Hook invoked by ``__call__`` before the first ``forward``; a no-op in the base class."""

    def __call__(self, *inputs, **kwargs):
        """
        Run the layer: forward pre-hooks, one-time build, ``forward``, then
        forward post-hooks.

        Everything runs inside ``param_guard`` for both ``_parameters`` and
        ``_buffers``.

        * Each forward pre-hook is called as ``hook(self, inputs)``. A non-None
          result replaces ``inputs``; a non-tuple result is wrapped in a 1-tuple.
        * While ``self._built`` is False, ``_build_once`` runs under
          ``program_desc_tracing_guard(False)``. In data-parallel mode on an XPU
          build, the parameters are also broadcast. ``_built`` is then set.
        * Each forward post-hook is called as ``hook(self, inputs, outputs)``.
          A non-None result replaces ``outputs``.

        Returns:
            The (possibly hook-replaced) result of ``forward``.
        """
        with param_guard(self._parameters), param_guard(self._buffers):
            for pre_hook in self._forward_pre_hooks.values():
                replacement = pre_hook(self, inputs)
                if replacement is None:
                    continue
                inputs = replacement if isinstance(
                    replacement, tuple) else (replacement, )

            if not self._built:
                with program_desc_tracing_guard(False):
                    self._build_once(*inputs, **kwargs)

                    # TODO(liuyuhui) Only xpu broadcast parameters here.
                    # Other devices call _sync_params_buffers in DataParallel
                    # to synchronize parameters among multiple cards.
                    in_parallel_mode = parallel_helper._is_data_parallel_mode()
                    if in_parallel_mode and paddle.is_compiled_with_xpu():
                        parallel_helper._broadcast_parameters(
                            self._parameters.values())

                self._built = True

            result = self.forward(*inputs, **kwargs)

            for post_hook in self._forward_post_hooks.values():
                replacement = post_hook(self, inputs, result)
                if replacement is not None:
                    result = replacement

            return result

    def forward(self, *inputs, **kwargs):
        """
        Define the computation performed at every call of the layer.

        Every subclass must override this method; the base version only raises.

        Parameters:
            *inputs(tuple): unpacked tuple arguments
            **kwargs(dict): unpacked dict arguments

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def backward(self, *inputs):
        """
        Reject explicit backward calls on a layer.

        Raises:
            ValueError: always.
        """
        message = "Layer shouldn't implement backward"
        raise ValueError(message)

    def add_sublayer(self, name, sublayer):
        """Adds a sub Layer instance.

        Added sublayer can be accessed by self.name

        Parameters:
            name(str): name of this sublayer.
            sublayer(Layer|None): an instance of Layer, or None to register an
                empty sublayer slot under ``name``.

        Returns:
            Layer: the sublayer passed in.

        Raises:
            AssertionError: if ``sublayer`` is neither None nor a Layer.

        Examples:
            .. code-block:: python

                import paddle

                class MySequential(paddle.nn.Layer):
                    def __init__(self, *layers):
                        super(MySequential, self).__init__()
                        if len(layers) > 0 and isinstance(layers[0], tuple):
                            for name, layer in layers:
                                self.add_sublayer(name, layer)
                        else:
                            for idx, layer in enumerate(layers):
                                self.add_sublayer(str(idx), layer)

                    def forward(self, input):
                        for layer in self._sub_layers.values():
                            input = layer(input)
                        return input

                fc1 = paddle.nn.Linear(10, 3)
                fc2 = paddle.nn.Linear(3, 10, bias_attr=False)
                model = MySequential(fc1, fc2)
                for prefix, layer in model.named_sublayers():
                    print(prefix, layer)

        """
        # Identity test for None (not ``==``); the message names the bad type.
        assert sublayer is None or isinstance(sublayer, core.Layer), (
            "The sublayer to be added should be a Layer or None, but received {}.".
            format(type(sublayer).__name__))

        self._sub_layers[name] = sublayer
        return sublayer

    def add_parameter(self, name, parameter):
        """Adds a Parameter instance.

        Added parameter can be accessed by self.name

        Parameters:
            name(str): name of this parameter.
            parameter(Parameter|None): an instance of Parameter, or None to
                register ``name`` as an empty parameter slot.

        Returns:
            Parameter: the parameter passed in.

        Raises:
            RuntimeError: if ``super().__init__()`` has not been called yet.
            TypeError: if ``name`` is not a string, or ``parameter`` is neither
                None nor a Parameter.
            KeyError: if ``name`` contains ``.``, is empty, or is already used by
                an attribute that is not a registered parameter.

        Examples:
            .. code-block:: python

                import paddle

                class MyLayer(paddle.nn.Layer):
                    def __init__(self):
                        super(MyLayer, self).__init__()
                        self._linear = paddle.nn.Linear(1, 1)
                        w_tmp = self.create_parameter([1,1])
                        self.add_parameter("w_tmp", w_tmp)

                    def forward(self, input):
                        return self._linear(input)

                mylayer = MyLayer()
                for name, param in mylayer.named_parameters():
                    print(name, param)      # will print w_tmp,_linear.weight,_linear.bias

        """
        if '_parameters' not in self.__dict__:
            raise RuntimeError(
                "super(YourLayer, self).__init__() should be called firstly.")
        elif not isinstance(name, six.string_types):
            raise TypeError(
                "The name of parameter should be a string, but received {}.".
                format(type(name).__name__))
        elif '.' in name:
            raise KeyError(
                "The name of parameter can not contain `.`, "
                "because when you access the newly added parameter in the "
                "form of `self.**.**`, it will cause AttributeError.")
        elif name == '':
            raise KeyError("The name of parameter can not be empty.")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("The parameter '{}' already exists.".format(name))
        elif parameter is not None and not isinstance(parameter,
                                                      framework.Parameter):
            raise TypeError(
                "The parameter to be added should be a Parameter, but received {}.".
                format(type(parameter).__name__))
        else:
            # A None placeholder has no ``name`` to look up, so only a real
            # parameter is refreshed from a pending ``_loaddict_holder``.
            # Previously a None parameter reached ``parameter.name`` and
            # raised AttributeError.
            if parameter is not None and len(self._loaddict_holder) > 0:
                assert parameter.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in state_dict".format(
                    parameter.name)

                parameter.set_value(self._loaddict_holder[parameter.name])

            self._parameters[name] = parameter
        return parameter

    def __getstate__(self):
        """Pickle support: return the instance ``__dict__`` itself (not a copy) as the state."""
        state = self.__dict__
        return state

    def __setstate__(self, state):
        """Pickle support: merge ``state`` into the instance ``__dict__``."""
        instance_dict = self.__dict__
        instance_dict.update(state)

    def __getattr__(self, name):
1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in self._parameters:
                return self._parameters[name]
        if '_sub_layers' in self.__dict__:
            _sub_layers = self.__dict__['_sub_layers']
            if name in self._sub_layers:
                return self._sub_layers[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        return object.__getattribute__(self, name)
X
Xin Pan 已提交
1047 1048

    def __setattr__(self, name, value):
        """
        Route an attribute assignment into the layer's registries.

        Dispatch, in order of precedence:

        * a ``framework.Parameter`` is stored in ``_parameters``. It is first
          refreshed from ``_loaddict_holder`` when that holder is non-empty.
        * a name already registered as a parameter may only be set to ``None``;
        * a ``core.Layer`` is stored in ``_sub_layers``;
        * a name already registered as a sub-layer may only be set to ``None``;
        * a value whose exact type is ``core.VarBase`` is stored in ``_buffers``.
          A new name is marked non-persistable.
        * a name already registered as a buffer accepts a ``framework.Variable``
          (written through ``paddle.assign``) or ``None``;
        * anything else becomes an ordinary instance attribute.

        Registering under one registry removes ``name`` from the instance
        ``__dict__`` and from the other two registries.

        Raises:
            ValueError: if the matching registry does not exist yet, i.e.
                ``super().__init__()`` has not been called.
            TypeError: when a name that is already a parameter, sub-layer or
                buffer is assigned a value of the wrong type.
        """

        def _discard_name(*containers):
            for container in containers:
                if name in container:
                    del container[name]

        # NOTE(review): after a class-level property handles the assignment,
        # execution still falls through to the dispatch below. For example, an
        # ordinary value reaches ``object.__setattr__`` (and the property
        # setter) a second time. This is kept as-is for compatibility; confirm
        # whether an early ``return`` is intended.
        if isinstance(getattr(type(self), name, None), property):
            object.__setattr__(self, name, value)

        instance_dict = self.__dict__
        param_dict = instance_dict.get('_parameters', None)
        layer_dict = instance_dict.get('_sub_layers', None)
        buffer_dict = instance_dict.get('_buffers', None)

        if isinstance(value, framework.Parameter):
            if param_dict is None:
                raise ValueError(
                    "super(YourLayer, self).__init__() should be called first")
            if len(self._loaddict_holder) > 0:
                assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in state_dict".format(
                    value.name)

                value.set_value(self._loaddict_holder[value.name])

            _discard_name(instance_dict, self._buffers, self._sub_layers)
            param_dict[name] = value
        elif param_dict is not None and name in param_dict:
            if value is not None:
                raise TypeError(
                    "assignment to parameter '{}' should be of type Parameter or None, but got '{}'"
                    .format(name, type(value).__name__))
            param_dict[name] = None
        elif isinstance(value, core.Layer):
            if layer_dict is None:
                raise ValueError(
                    "super(YourLayer, self).__init__() should be called first")
            _discard_name(instance_dict, self._parameters, self._buffers)
            layer_dict[name] = value
        elif layer_dict is not None and name in layer_dict:
            if value is not None:
                raise TypeError(
                    "assignment to sublayer '{}' should be of type Layer or None, but got '{}'"
                    .format(name, type(value).__name__))
            layer_dict[name] = None
        elif type(value) == core.VarBase:
            if buffer_dict is None:
                raise ValueError(
                    "super(YourLayer, self).__init__() should be called first")
            _discard_name(instance_dict, self._parameters, self._sub_layers)
            # Set persistable=False by default. Only `register_buffer` can
            # add a persistable buffer.
            if name not in self._buffers:
                self._non_persistable_buffer_names_set.add(name)
            buffer_dict[name] = value
        elif buffer_dict is not None and name in buffer_dict:
            if type(value) == framework.Variable:
                # Note(Aurelius84): In Dy2stat, the value of the Buffer may be
                # modified in the decorated function, such as
                # `self.buffer = new_tensor`, so its value is updated via `assign`.
                from paddle import assign
                # Note(zhhsplendid): the condition below happens in PaddleGan
                # model. Whether every non-Variable buffer should be re-assigned
                # is still open; this branch is deliberately conservative.
                current = buffer_dict[name]
                if current is None or type(current) == core.VarBase:
                    buffer_dict[name] = assign(value)
                else:
                    assign(value, current)
            elif value is not None:
                raise TypeError(
                    "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
                    .format(name, type(value).__name__))
            else:
                # Assigning None clears the buffer. A VarBase assigned later
                # under the same name keeps its recorded `persistable` status.
                buffer_dict[name] = None
        else:
            object.__setattr__(self, name, value)

    def __delattr__(self, name):
        """
        Delete ``name`` from the first registry holding it.

        Registries are checked in order: parameters, sub-layers, buffers.
        Deleting a buffer also drops it from the non-persistable name set. A
        name in no registry is deleted as an ordinary attribute.
        """
        if name in self._parameters:
            del self._parameters[name]
            return
        if name in self._sub_layers:
            del self._sub_layers[name]
            return
        if name in self._buffers:
            del self._buffers[name]
            self._non_persistable_buffer_names_set.discard(name)
            return
        object.__delattr__(self, name)

    def __dir__(self):
        """
        Return a list of the layer's visible names.

        The list covers class attributes and methods, instance attributes, and
        registered parameters, sub-layers and buffers (non-parameter tensors),
        in that order.

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np

                class Mylayer(paddle.nn.Layer):
                    def __init__(self):
                        super(Mylayer, self).__init__()
                        self.linear1 = paddle.nn.Linear(10, 10)
                        self.linear2 = paddle.nn.Linear(5, 5)
                        self.conv2d = paddle.nn.Conv2D(3, 2, 3)
                        self.embedding = paddle.nn.Embedding(128, 16)
                        self.h_0 = paddle.to_tensor(np.zeros([10, 10]).astype('float32'))

                mylayer = Mylayer()
                print(dir(mylayer))
                # only part of the (long) list is shown:
                # ['__call__', '__class__',  ... , 'conv2d', 'embedding', 'h_0', 'linear1', 'linear2', ... , 'sublayers', 'train']

        """
        names = dir(self.__class__)
        for mapping in (self.__dict__, self._parameters, self._sub_layers,
                        self._buffers):
            names.extend(mapping.keys())
        return names

    def extra_repr(self):
        """
        Return extra text placed inside this layer's ``repr``.

        The base implementation contributes nothing (an empty string). Subclasses
        may override it to describe their own configuration.
        """
        return str()

    def __repr__(self):
        """
        Build a readable, nested description of this layer.

        The result is ``ClassName(<extra_repr>)``, followed by one indented
        ``(name): <repr of sub-layer>`` line per registered sub-layer. A
        multi-line ``extra_repr`` is placed on its own indented lines.
        """
        # ``str.split`` always returns at least one item, so ``extra_lines`` is
        # never empty (a layer without extra info gives ``['']``).
        extra_lines = self.extra_repr().split('\n')
        sublayer_lines = [
            '(' + name + '): ' + _addindent(repr(layer), 2)
            for name, layer in self._sub_layers.items()
        ]

        final_str = self.__class__.__name__ + '('
        if len(extra_lines) > 1:
            final_str += '\n  ' + '\n  '.join(extra_lines) + '\n'
        else:
            final_str += extra_lines[0]
        if sublayer_lines:
            final_str += '\n  ' + '\n  '.join(sublayer_lines) + '\n'

        final_str += ')'
        return final_str

    def state_dict(self,
                   destination=None,
                   include_sublayers=True,
                   structured_name_prefix=""):
        '''
        Get all parameters and persistable buffers of current layer and its sub-layers, and put them into a dict.

        Parameters:
            destination(dict, optional) : If provided, all the parameters and persistable buffers are written into this dict, which is filled in place and returned. Default: None
            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
            structured_name_prefix(str, optional) : Prefix prepended to every key. Entries of a sub-layer are keyed as ``<prefix><sublayer_name>.<name>``. Default: ""

        Returns:
            dict: a dict that contains all the parameters and persistable buffers.

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save( state_dict, "paddle_dy.pdparams")

        '''

        if destination is None:
            destination = collections.OrderedDict()
        for name, data in self._parameters.items():
            if data is not None:
                destination[structured_name_prefix + name] = data
        for name, buffer in self._buffers.items():
            # Buffers marked non-persistable are not saved.
            if buffer is not None and name not in self._non_persistable_buffer_names_set:
                destination[structured_name_prefix + name] = buffer

        if include_sublayers:
            for layer_name, layer_item in self._sub_layers.items():
                if layer_item is not None:
                    # Fill the same dict instead of copying it once per
                    # sub-layer. Copying was quadratic in the number of
                    # entries, and it hid sub-layer entries from a
                    # caller-supplied ``destination``. The ``update`` keeps
                    # overrides that return a new dict working.
                    destination.update(
                        layer_item.state_dict(
                            destination, include_sublayers,
                            structured_name_prefix + layer_name + "."))
        return destination

    @framework.deprecate_stat_dict
    def set_state_dict(self, state_dict, use_structured_name=True):
        '''
        Set parameters and persistable buffers from state_dict.

        Every entry of ``self.state_dict()`` is looked up in ``state_dict`` and
        reset from it. An entry that is missing or has a different shape is
        skipped with a warning.

        Parameters:
            state_dict(dict) : Dict that contains all the parameters and persistable buffers.
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                emb = paddle.nn.Embedding(10, 10)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy.pdparams")
                para_state_dict = paddle.load("paddle_dy.pdparams")
                emb.set_state_dict(para_state_dict)

        '''

        def _matched_state(key, param):
            # Return the value stored under ``key`` after checking its shape
            # against ``param``. Problems are reported as ValueError.
            state = state_dict.get(key, None)
            if state is None:
                raise ValueError("{} is not found in the provided dict.".format(
                    key))
            # Some value types expose ``shape`` as a method, others as an
            # attribute.
            state_shape = state.shape() if inspect.ismethod(
                state.shape) else state.shape
            expected_shape = list(param.shape)
            if list(state_shape) != expected_shape:
                raise ValueError(
                    "{} receives a shape {}, but the expected shape is {}.".
                    format(key, list(state_shape), expected_shape))
            return state

        matched_param_state = []
        for key, param in self.state_dict().items():
            lookup_key = key if use_structured_name else param.name
            try:
                state = _matched_state(lookup_key, param)
            except ValueError as err:
                warnings.warn("Skip loading for {}. ".format(key) + str(err))
                continue
            matched_param_state.append((param, state))

        if in_dygraph_mode():
            for param, state in matched_param_state:
                param.set_value(state)
            return

        def _set_var(var, ndarray):
            # Write ``ndarray`` (whatever value type the incoming dict holds)
            # into the scope tensor of ``var``, on the place it already lives on.
            tensor = global_scope().find_var(var.name).get_tensor()
            src_place = tensor._place()
            if src_place.is_cpu_place():
                dst_place = core.CPUPlace()
            elif src_place.is_cuda_pinned_place():
                dst_place = core.CUDAPinnedPlace()
            else:
                generic_place = core.Place()
                generic_place.set_place(tensor._place())
                if src_place.is_xpu_place():
                    dst_place = core.XPUPlace(generic_place.xpu_device_id())
                else:
                    dst_place = core.CUDAPlace(generic_place.gpu_device_id())
            tensor.set(ndarray, dst_place)

        executor = Executor(_get_device())._default_executor
        # restore parameter states
        core._create_loaded_parameter(
            [param for param, _ in matched_param_state], global_scope(),
            executor)
        for param, state in matched_param_state:
            _set_var(param, state)

    def _apply(self, func, device, dtype, blocking):
        """
        Replace tensors recursively with ``func(tensor, device, dtype, blocking)``.

        This covers this layer's parameters, their gradients and its buffers.
        Child layers are processed first. Each new parameter inherits the
        ``stop_gradient`` flag of the one it replaces. An existing gradient is
        transformed with the same ``func`` and attached to the new parameter.
        ``None`` entries in ``_parameters`` and ``_buffers`` are left untouched.
        """
        for layer in self.children():
            layer._apply(func, device, dtype, blocking)

        for key, param in self._parameters.items():
            if param is None:
                continue
            with no_grad():
                param_applied = func(param, device, dtype, blocking)
                assert param.is_leaf
                param_applied.stop_gradient = param.stop_gradient
                self._parameters[key] = param_applied

            if param.grad is not None:
                with no_grad():
                    grad_applied = func(param._grad_ivar(), device, dtype,
                                        blocking)

                    grad_applied.stop_gradient = param._grad_ivar(
                    ).stop_gradient
                    self._parameters[key]._set_grad_ivar(grad_applied)

        for key, buf in self._buffers.items():
            # Buffers may legitimately be None (``register_buffer(name, None)``
            # or ``self.<buffer> = None``). ``func`` expects a tensor, so such
            # slots are skipped rather than crashing.
            if buf is not None:
                self._buffers[key] = func(buf, device, dtype, blocking)

    def to(self, device=None, dtype=None, blocking=None):
        '''
        Cast the parameters and buffers of Layer to the given device, dtype and blocking mode.

        Parameters:
            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device where the Layer should be stored.
            If None, the device is the same as the original Tensor. If device is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
            index of the GPUs or XPUs. Default: None.

            dtype(str|core.VarDesc.VarType|None, optional): The type of the data. If None, the dtype is the same as the original Tensor. Default: None.

            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, blocking is set to True. Default: None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                linear=paddle.nn.Linear(2, 2)
                linear.weight
                #Parameter containing:
                #Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(dtype='float64')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

                linear.to(device='cpu')
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])
                linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
                linear.weight
                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
                #       [[-0.32770029,  0.38653070],
                #        [ 0.46030545,  0.08158520]])

        '''

        # Nothing requested: leave every tensor where it is.
        if device is None and dtype is None and blocking is None:
            return

        if isinstance(device, str):
            device = paddle.device._convert_to_place(device)
        elif device is not None and not isinstance(
                device, (core.CPUPlace, core.CUDAPlace, core.CUDAPinnedPlace,
                         core.XPUPlace)):
            raise ValueError(
                "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
                + type(device).__name__)

        if blocking is not None:
            assert isinstance(
                blocking,
                bool), "blocking value error, must be the True, False or None"
        blocking = True if blocking is None else blocking

        def transform(t, device, dtype, blocking):
            # Copy ``t`` to the target place (its own place when none was
            # given), then cast only if a different dtype was requested.
            target_place = t.place if device is None else device
            target_dtype = t.dtype if dtype is None else dtype

            moved = t._copy_to(target_place, blocking)
            if target_dtype is not None and target_dtype != t.dtype:
                moved = moved.cast(dtype=target_dtype)
            return moved

        self._apply(transform, device, dtype, blocking)

    # [aliases] Compatible with old method names: both names are bound to the
    # same (decorated) function object as ``set_state_dict``.
    set_dict = set_state_dict
    load_dict = set_state_dict