# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from collections.abc import Iterable, Mapping

from ...fluid.dygraph.base import param_guard
from ...fluid.framework import Parameter
from .layers import Layer

# Nothing is exported via star-import from this module; the containers are
# exposed through the package's own namespace.
__all__ = []


class LayerDict(Layer):
    """
    LayerDict holds sublayers in the ordered dictionary, and sublayers it contains are properly registered.
    Held sublayers can be accessed like a regular ordered python dictionary.

    Parameters:
        sublayers (LayerDict|OrderedDict|list[(key,Layer)...], optional): iterable of key/value pairs, the type of value is 'paddle.nn.Layer'.

    Examples:
        .. code-block:: python

            import paddle
            from collections import OrderedDict

            sublayers = OrderedDict([
                ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
            ])

            layers_dict = paddle.nn.LayerDict(sublayers=sublayers)

            l = layers_dict['conv1d']

            for k in layers_dict:
                l = layers_dict[k]

            len(layers_dict)
            #3

            del layers_dict['conv2d']
            len(layers_dict)
            #2

            conv1d = layers_dict.pop('conv1d')
            len(layers_dict)
            #1

            layers_dict.clear()
            len(layers_dict)
            #0

    """

    def __init__(self, sublayers=None):
        super().__init__()
        if sublayers is not None:
            self.update(sublayers)

    def __getitem__(self, key):
        return self._sub_layers[key]

    def __setitem__(self, key, sublayer):
        # Go through add_sublayer so the layer is registered like any other
        # sublayer of this Layer, rather than writing _sub_layers directly.
        return self.add_sublayer(key, sublayer)

    def __delitem__(self, key):
        del self._sub_layers[key]

    def __len__(self):
        return len(self._sub_layers)

    def __iter__(self):
        # Iterating a LayerDict yields its keys, like a regular dict.
        return iter(self._sub_layers)

    def __contains__(self, key):
        return key in self._sub_layers

    def clear(self):
        """
        Clear all the sublayers in the LayerDict.

        Parameters:
            None.

        Examples:
            .. code-block:: python

                import paddle
                from collections import OrderedDict

                sublayers = OrderedDict([
                    ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                    ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                    ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
                ])

                layer_dict = paddle.nn.LayerDict(sublayers=sublayers)
                len(layer_dict)
                #3

                layer_dict.clear()
                len(layer_dict)
                #0

        """
        self._sub_layers.clear()

    def pop(self, key):
        """
        Remove the key from the LayerDict and return the layer of the key.

        Parameters:
            key (str): the key to be removed.

        Returns:
            Layer: the sublayer that was stored under ``key``.

        Raises:
            KeyError: if ``key`` is not present in the LayerDict.

        Examples:
            .. code-block:: python

                import paddle
                from collections import OrderedDict

                sublayers = OrderedDict([
                    ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                    ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                    ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
                ])

                layer_dict = paddle.nn.LayerDict(sublayers=sublayers)
                len(layer_dict)
                #3

                layer_dict.pop('conv2d')
                len(layer_dict)
                #2

        """
        # Look up first so a missing key raises before anything is removed.
        v = self[key]
        del self[key]
        return v

    def keys(self):
        """
        Return the iterable of the keys in LayerDict.

        Parameters:
            None.

        Examples:
            .. code-block:: python

                import paddle
                from collections import OrderedDict

                sublayers = OrderedDict([
                    ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                    ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                    ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
                ])

                layer_dict = paddle.nn.LayerDict(sublayers=sublayers)
                for k in layer_dict.keys():
                    print(k)

                #conv1d
                #conv2d
                #conv3d

        """
        return self._sub_layers.keys()

    def items(self):
        """
        Return the iterable of the key/value pairs in LayerDict.

        Parameters:
            None.

        Examples:
            .. code-block:: python

                import paddle
                from collections import OrderedDict

                sublayers = OrderedDict([
                    ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                    ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                    ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
                ])

                layer_dict = paddle.nn.LayerDict(sublayers=sublayers)
                for k, v in layer_dict.items():
                    print(k, ":", v)

                #conv1d : Conv1D(3, 2, kernel_size=[3], data_format=NCL)
                #conv2d : Conv2D(3, 2, kernel_size=[3, 3], data_format=NCHW)
                #conv3d : Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)

        """
        return self._sub_layers.items()

    def values(self):
        """
        Return the iterable of the values in LayerDict.

        Parameters:
            None.

        Examples:
            .. code-block:: python

                import paddle
                from collections import OrderedDict

                sublayers = OrderedDict([
                    ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                    ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                    ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
                ])

                layer_dict = paddle.nn.LayerDict(sublayers=sublayers)
                for v in layer_dict.values():
                    print(v)

                #Conv1D(3, 2, kernel_size=[3], data_format=NCL)
                #Conv2D(3, 2, kernel_size=[3, 3], data_format=NCHW)
                #Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)

        """
        return self._sub_layers.values()

    def update(self, sublayers):
        """
        Update the key/values pairs in sublayers to the LayerDict, overwriting the existing keys.

        Parameters:
            sublayers (LayerDict|OrderedDict|list[(key,Layer)...]): iterable of key/value pairs, the type of value is 'paddle.nn.Layer'.

        Raises:
            AssertionError: if ``sublayers`` is not iterable.
            ValueError: if an element of a non-mapping ``sublayers`` is not a
                (key, layer) pair.

        Examples:
            .. code-block:: python

                import paddle
                from collections import OrderedDict

                sublayers = OrderedDict([
                    ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
                    ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
                    ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
                ])

                new_sublayers = OrderedDict([
                    ('relu', paddle.nn.ReLU()),
                    ('conv2d', paddle.nn.Conv2D(4, 2, 4)),
                ])
                layer_dict = paddle.nn.LayerDict(sublayers=sublayers)

                layer_dict.update(new_sublayers)

                for k, v in layer_dict.items():
                    print(k, ":", v)
                #conv1d : Conv1D(3, 2, kernel_size=[3], data_format=NCL)
                #conv2d : Conv2D(4, 2, kernel_size=[4, 4], data_format=NCHW)
                #conv3d : Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)
                #relu : ReLU()

        """

        assert isinstance(sublayers, Iterable), (
            "The type of sublayers is not iterable of key/value pairs, the type of sublayers is "
            + type(sublayers).__name__
        )

        # OrderedDict is already a Mapping; LayerDict is listed explicitly
        # because it exposes .items() without inheriting from Mapping.
        if isinstance(sublayers, (LayerDict, Mapping)):
            for key, layer in sublayers.items():
                self.add_sublayer(key, layer)
        else:
            # handle this format [(key1, layer1), (key2, layer2)...]
            for i, kv in enumerate(sublayers):
                if len(kv) != 2:
                    raise ValueError(
                        f"The length of the {i}'s element in sublayers is "
                        f"{len(kv)}, which must be 2."
                    )
                self.add_sublayer(kv[0], kv[1])


class ParameterList(Layer):
    """ParameterList Container.

    This container acts like a Python list, but parameters it contains will be properly added.
    Integer indices may be negative to count from the end, as with a Python list.

    Parameters:
        parameters (iterable, optional): Iterable Parameters to be added

    Examples:
        .. code-block:: python

            import paddle

            class MyLayer(paddle.nn.Layer):
                def __init__(self, num_stacked_param):
                    super().__init__()
                    # create ParameterList with iterable Parameters
                    self.params = paddle.nn.ParameterList(
                        [paddle.create_parameter(shape=[2, 2], dtype='float32')
                         for _ in range(num_stacked_param)])

                def forward(self, x):
                    for i, p in enumerate(self.params):
                        tmp = self._helper.create_variable_for_type_inference('float32')
                        self._helper.append_op(
                            type="mul",
                            inputs={"X": x,
                                    "Y": p},
                            outputs={"Out": tmp},
                            attrs={"x_num_col_dims": 1,
                                    "y_num_col_dims": 1})
                        x = tmp
                    return x

            x = paddle.uniform(shape=[5, 2], dtype='float32')
            num_stacked_param = 4
            model = MyLayer(num_stacked_param)
            print(len(model.params))  # 4
            res = model(x)
            print(res.shape)  # [5, 2]

            replaced_param = paddle.create_parameter(shape=[2, 3], dtype='float32')
            model.params[-1] = replaced_param  # replace last param
            res = model(x)
            print(res.shape)  # [5, 3]
            model.params.append(paddle.create_parameter(shape=[3, 4], dtype='float32'))  # append param
            print(len(model.params))  # 5
            res = model(x)
            print(res.shape)  # [5, 4]
    """

    def __init__(self, parameters=None):
        super().__init__()
        if parameters is not None:
            for idx, param in enumerate(parameters):
                assert isinstance(param, Parameter)
                self.add_parameter(str(idx), param)

    def _get_abs_idx(self, idx):
        """Map an in-range negative integer index onto its non-negative position.

        Any other value (non-negative ints, out-of-range negatives, non-ints)
        is returned unchanged so the caller's existing handling applies.
        """
        if isinstance(idx, int) and -len(self) <= idx < 0:
            idx += len(self)
        return idx

    def __getitem__(self, idx):
        idx = self._get_abs_idx(idx)
        with param_guard(self._parameters):
            return self._parameters[str(idx)]

    def __setitem__(self, idx, param):
        assert isinstance(param, Parameter)
        idx = self._get_abs_idx(idx)
        if isinstance(idx, int) and idx < 0:
            # Without this check a parameter literally named "-k" would be
            # registered, silently growing the list and breaking append().
            raise IndexError(f'index {idx} is out of range')
        setattr(self, str(idx), param)

    def __len__(self):
        return len(self._parameters)

    def __iter__(self):
        with param_guard(self._parameters):
            # Materialise the values while the guard is active. A lazy
            # dict-values iterator would only be consumed after the guard has
            # exited, so it would not see the entries __getitem__ sees.
            return iter(list(self._parameters.values()))

    def append(self, parameter):
        """Appends a given parameter at the end of the list.

        Parameters:
            parameter (Parameter): parameter to append

        Returns:
            ParameterList: this container, to allow chaining.
        """
        idx = len(self._parameters)
        self.add_parameter(str(idx), parameter)
        return self


class LayerList(Layer):
    """
    LayerList holds sublayers, and sublayers it contains are properly registered.
    Held sublayers can be indexed like a regular python list.

    Parameters:
        sublayers (iterable of Layer, optional): sublayers to hold

    Examples:
        .. code-block:: python

            import paddle

            class MyLayer(paddle.nn.Layer):
                def __init__(self):
                    super().__init__()
                    self.linears = paddle.nn.LayerList(
                        [paddle.nn.Linear(10, 10) for i in range(10)])

                def forward(self, x):
                    # LayerList can act as an iterable, or be indexed using ints
                    for i, l in enumerate(self.linears):
                        x = self.linears[i // 2](x) + l(x)
                    return x
    """

    def __init__(self, sublayers=None):
        super().__init__()
        if sublayers is not None:
            for idx, layer in enumerate(sublayers):
                self.add_sublayer(str(idx), layer)

    def _get_abs_idx(self, idx):
        """Validate an int index and convert a negative one to its position.

        Non-int values (e.g. slices) are returned unchanged.

        Raises:
            IndexError: if an int index is outside ``[-len(self), len(self))``.
        """
        if isinstance(idx, int):
            if not (-len(self) <= idx < len(self)):
                raise IndexError(
                    'index {} is out of range, should be an integer in range [{}, {})'.format(
                        idx, -len(self), len(self)
                    )
                )
            if idx < 0:
                idx += len(self)
        return idx

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self.__class__(list(self._sub_layers.values())[idx])
        else:
            idx = self._get_abs_idx(idx)
            return self._sub_layers[str(idx)]

    def __setitem__(self, idx, sublayer):
        idx = self._get_abs_idx(idx)
        return setattr(self, str(idx), sublayer)

    def __delitem__(self, idx):
        if isinstance(idx, slice):
            # Keys are still "0".."n-1" here, so positions map directly to keys.
            for k in range(len(self._sub_layers))[idx]:
                delattr(self, str(k))
        else:
            idx = self._get_abs_idx(idx)
            delattr(self, str(idx))
        # Renumber the survivors so keys stay contiguous ("0".."len-1").
        str_indices = [str(i) for i in range(len(self._sub_layers))]
        self._sub_layers = OrderedDict(
            list(zip(str_indices, self._sub_layers.values()))
        )

    def __len__(self):
        return len(self._sub_layers)

    def __iter__(self):
        return iter(self._sub_layers.values())

    def append(self, sublayer):
        """
        Appends a sublayer to the end of the list.

        Parameters:
            sublayer (Layer): sublayer to append

        Returns:
            LayerList: this container, to allow chaining.

        Examples:
            .. code-block:: python

                import paddle

                linears = paddle.nn.LayerList([paddle.nn.Linear(10, 10) for i in range(10)])
                another = paddle.nn.Linear(10, 10)
                linears.append(another)
                print(len(linears))  # 11
        """
        self.add_sublayer(str(len(self)), sublayer)
        return self

    def insert(self, index, sublayer):
        """
        Insert a sublayer before a given index in the list.

        Parameters:
            index (int): index to insert, in range ``[-len(self), len(self)]``.
                Negative values count from the end; ``len(self)`` appends,
                as with ``list.insert``.
            sublayer (Layer): sublayer to insert

        Examples:
            .. code-block:: python

                import paddle

                linears = paddle.nn.LayerList([paddle.nn.Linear(10, 10) for i in range(10)])
                another = paddle.nn.Linear(10, 10)
                linears.insert(3, another)
                print(linears[3] is another)  # True
                another = paddle.nn.Linear(10, 10)
                linears.insert(-1, another)
                print(linears[-2] is another) # True
                another = paddle.nn.Linear(10, 10)
                linears.insert(len(linears), another)
                print(linears[-1] is another) # True
        """
        size = len(self._sub_layers)
        # The upper bound is inclusive so that inserting at the end (and into
        # an empty LayerList) works like list.insert.
        assert isinstance(index, int) and -size <= index <= size, (
            "index should be an integer in range [{}, {}]".format(-size, size)
        )

        if index < 0:
            index += size
        # Shift entries at or after ``index`` one slot right, walking backwards
        # so each entry is copied before its slot is overwritten.
        for i in range(size, index, -1):
            self._sub_layers[str(i)] = self._sub_layers[str(i - 1)]
        self._sub_layers[str(index)] = sublayer

    def extend(self, sublayers):
        """
        Appends sublayers to the end of the list.

        Parameters:
            sublayers (iterable of Layer): iterable of sublayers to append

        Returns:
            LayerList: this container, to allow chaining.

        Examples:
            .. code-block:: python

                import paddle

                linears = paddle.nn.LayerList([paddle.nn.Linear(10, 10) for i in range(10)])
                another_list = paddle.nn.LayerList([paddle.nn.Linear(10, 10) for i in range(5)])
                linears.extend(another_list)
                print(len(linears))  # 15
                print(another_list[0] is linears[10])  # True
        """
        offset = len(self)
        for i, sublayer in enumerate(sublayers):
            idx = str(offset + i)
            self.add_sublayer(idx, sublayer)
        return self


class Sequential(Layer):
    """Sequential container.
    Sub layers will be added to this container in the order of argument in the constructor.
    The argument passed to the constructor can be iterable Layers or iterable name Layer pairs.

    Integer indices are positional (negative values count from the end) for
    reading, assigning and deleting; string indices address layers by name.

    Parameters:
        layers(Layer|list|tuple): Layer or list/tuple of iterable name Layer pair.

    Examples:
        .. code-block:: python

            import paddle

            data = paddle.uniform(shape=[30, 10], dtype='float32')
            # create Sequential with iterable Layers
            model1 = paddle.nn.Sequential(
                paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2)
            )
            model1[0]  # access the first layer
            res1 = model1(data)  # sequential execution

            # create Sequential with name Layer pairs
            model2 = paddle.nn.Sequential(
                ('l1', paddle.nn.Linear(10, 2)),
                ('l2', paddle.nn.Linear(2, 3))
            )
            model2['l1']  # access l1 layer
            model2.add_sublayer('l3', paddle.nn.Linear(3, 3))  # add sublayer
            res2 = model2(data)  # sequential execution

    """

    def __init__(self, *layers):
        super().__init__()
        if len(layers) > 0 and isinstance(layers[0], (list, tuple)):
            for name, layer in layers:
                self.add_sublayer(name, layer)
        else:
            for idx, layer in enumerate(layers):
                self.add_sublayer(str(idx), layer)

    def _key_at(self, name):
        """Resolve ``name`` to the key used in ``self._sub_layers``.

        An in-range integer (negative counts from the end) resolves to the key
        stored at that position, matching ``__getitem__``. Anything else falls
        back to ``str(name)``.
        """
        size = len(self._sub_layers)
        if isinstance(name, int) and -size <= name < size:
            return list(self._sub_layers.keys())[name]
        return str(name)

    def __getitem__(self, name):
        if isinstance(name, slice):
            return self.__class__(*(list(self._sub_layers.values())[name]))
        elif isinstance(name, str):
            return self._sub_layers[name]
        else:
            if name >= len(self._sub_layers):
                raise IndexError(f'index {name} is out of range')
            elif name < 0 and name >= -len(self._sub_layers):
                name += len(self._sub_layers)
            elif name < -len(self._sub_layers):
                raise IndexError(f'index {name} is out of range')
            return list(self._sub_layers.values())[name]

    def __setitem__(self, name, layer):
        assert isinstance(layer, Layer)
        # Resolve positions to the stored key so that e.g. seq[-1] = layer
        # replaces the last layer instead of appending one named "-1".
        setattr(self, self._key_at(name), layer)

    def __delitem__(self, name):
        name = self._key_at(name)
        assert name in self._sub_layers
        del self._sub_layers[name]

    def __len__(self):
        return len(self._sub_layers)

    def forward(self, input):
        # Feed each sublayer's output into the next, in insertion order.
        for layer in self._sub_layers.values():
            input = layer(input)
        return input