# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the common classes to build a neural network
import paddle
from paddle import in_dynamic_mode

from .. import functional as F
from .layers import Layer

__all__ = []


def _npairs(x, n):
    """Broadcast a scalar ``x`` to ``n * 2`` pad entries; sequences and Tensors pass through unchanged."""
    if isinstance(x, (paddle.Tensor, list, tuple)):
        return x
    x = [x] * (n * 2)
    return x
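
# Illustrative behaviour of the helper above (example values, not taken from
# the library's test suite): a scalar is broadcast to ``n * 2`` pad entries,
# while a list/tuple/Tensor is passed through unchanged, e.g.
#   _npairs(1, 2)      -> [1, 1, 1, 1]
#   _npairs([1, 2], 1) -> [1, 2]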


class Identity(Layer):
    r"""

    A placeholder identity operator that is argument-insensitive. For each input :math:`X` ,
    the output :math:`Out` is:

    .. math::

        Out = X

    Parameters:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, n1, n2, ...]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, n1, n2, ...]` .

    Examples:
        .. code-block:: python

          import paddle

          input_tensor = paddle.randn(shape=[3, 2])
          layer = paddle.nn.Identity()
          out = layer(input_tensor)
          # input_tensor: [[-0.32342386 -1.200079  ]
          #                [ 0.7979031  -0.90978354]
          #                [ 0.40597573  1.8095392 ]]
          # out: [[-0.32342386 -1.200079  ]
          #      [ 0.7979031  -0.90978354]
          #      [ 0.40597573  1.8095392 ]]


    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        return input


class Linear(Layer):
    r"""

    Fully-connected linear transformation layer. For each input :math:`X` ,
    the equation is:

    .. math::

        Out = XW + b

    where :math:`W` is the weight and :math:`b` is the bias.

    Linear layer takes only one multi-dimensional tensor as input with the
    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
    number of additional dimensions. It multiplies input tensor with the weight
    (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
    an output tensor of shape :math:`[batch\_size, *, out\_features]` .
    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
    shape :math:`[out\_features]` ) will be created and added to the output.

    Parameters:
        in_features (int): The number of input units.
        out_features (int): The number of output units.
        weight_attr (ParamAttr, optional): The attribute for the learnable
            weight of this layer. The default value is None. If the Initializer of the
            param_attr is not set, the parameter is initialized with Xavier.
            For detailed information, please refer to paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.
        name (str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .

    Attribute:
        **weight** (Parameter): the learnable weight of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` . Its data type can be float16, float32 or float64. The default is float32.
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` . The data type is the same as the input.

    Examples:
        .. code-block:: python

          import paddle

          # Define the linear layer.
          weight_attr = paddle.ParamAttr(
              name="weight",
              initializer=paddle.nn.initializer.Constant(value=0.5))
          bias_attr = paddle.ParamAttr(
              name="bias",
              initializer=paddle.nn.initializer.Constant(value=1.0))
          linear = paddle.nn.Linear(2, 4, weight_attr=weight_attr, bias_attr=bias_attr)
          # linear.weight: [[0.5 0.5 0.5 0.5]
          #                 [0.5 0.5 0.5 0.5]]
          # linear.bias: [1. 1. 1. 1.]

          x = paddle.randn((3, 2), dtype="float32")
          # x: [[-0.32342386 -1.200079  ]
          #     [ 0.7979031  -0.90978354]
          #     [ 0.40597573  1.8095392 ]]
          y = linear(x)
          # y: [[0.23824859 0.23824859 0.23824859 0.23824859]
          #     [0.9440598  0.9440598  0.9440598  0.9440598 ]
          #     [2.1077576  2.1077576  2.1077576  2.1077576 ]]
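
          # The layer also accepts inputs with extra leading dimensions,
          # matching the [batch_size, *, in_features] shape described above
          # (illustrative values):
          x3 = paddle.randn((2, 3, 2), dtype="float32")
          y3 = linear(x3)  # y3 has shape [2, 3, 4]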
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
    ):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.name = name

    def forward(self, input):
        out = F.linear(
            x=input, weight=self.weight, bias=self.bias, name=self.name
        )
        return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str
        )


class LinearCompress(Layer):
    r"""

    Fully-connected linear transformation layer. For each input :math:`X` ,
    the equation is:

    .. math::

        Out = XW + b

    where :math:`W` is the weight and :math:`b` is the bias.

    Linear layer takes only one multi-dimensional tensor as input with the
    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
    number of additional dimensions. It multiplies input tensor with the weight
    (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
    an output tensor of shape :math:`[batch\_size, *, out\_features]` .
    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
    shape :math:`[out\_features]` ) will be created and added to the output.

    Parameters:
        in_features (int): The number of input units.
        out_features (int): The number of output units.
        weight_attr (ParamAttr, optional): The attribute for the weight of this layer.
            The default value is None. If the Initializer of the
            param_attr is not set, the parameter is initialized with Xavier.
            For detailed information, please refer to paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the bias of this layer.
            If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.
        name (str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .
        bits (int, optional): The number of bits used to quantize the weight in weight_only mode,
            it must be set to 8. Default: 8.
        algo (str, optional): The compression algorithm, it must be set to 'weight_only'
            or 'llm.int8'. Default: 'weight_only'.
        config (dict, optional): The parameter config for the compression algorithm.
            For llm.int8, it should be set as {'threshold': 6.0}. Default: {'threshold': 6.0}.

    Attribute:
        **weight** (Parameter): the learnable weight of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` . Its data type should be float16.
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` . The data type is the same as the input.

    Examples:
        .. code-block:: python

          import paddle

          # Define the linear layer.
          paddle.set_default_dtype('float16')
          weight_attr = paddle.ParamAttr(
              name="weight",
              initializer=paddle.nn.initializer.Constant(value=0.5))
          bias_attr = paddle.ParamAttr(
              name="bias",
              initializer=paddle.nn.initializer.Constant(value=1.0))
          linear = paddle.nn.LinearCompress(128, 64, weight_attr=weight_attr, bias_attr=bias_attr, bits=8, algo='weight_only')
          x = paddle.randn((3, 128), dtype="float16")
          y = linear(x)
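
          # y has shape [3, 64]. In dynamic graph mode the weight is quantized
          # to int8 lazily on the first forward pass (illustrative note based
          # on the forward() implementation below).
          print(y.shape)  # [3, 64]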
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
        bits=8,
        algo="weight_only",
        config={'threshold': 6.0},
    ):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.weight_scale = self.create_parameter(
            shape=[out_features],
            attr=None,
            dtype=self._dtype,
            is_bias=False,
        )
        self.is_weight_quanted = False
        self.name = name
        self.bits = bits
        self.layout = algo
        self.algo = algo
        self.config = config

    def forward(self, input):
        if in_dynamic_mode():
            if not self.is_weight_quanted:
                weight_tensor, weight_scale_tensor = F.quant_for_compress(
                    self.weight, self.bits, self.layout
                )
                weight_attr = paddle.framework.ParamAttr(
                    initializer=paddle.nn.initializer.Assign(weight_tensor)
                )
                self.weight = self.create_parameter(
                    shape=self.weight.shape
                    if self.layout == 0
                    else [self.weight.shape[1], self.weight.shape[0]],
                    attr=weight_attr,
                    dtype="int8",
                    is_bias=False,
                )
                weight_scale_attr = paddle.framework.ParamAttr(
                    initializer=paddle.nn.initializer.Assign(
                        weight_scale_tensor
                    )
                )
                self.weight_scale = self.create_parameter(
                    shape=self.weight_scale.shape,
                    attr=weight_scale_attr,
                    dtype="float32",
                    is_bias=False,
                )
                self.is_weight_quanted = True
            out = F.linear_compress(
                x=input,
                weight=self.weight,
                weight_scale=self.weight_scale,
                bias=self.bias,
                bits=self.bits,
                algo=self.algo,
                name=self.name,
                config=self.config,
            )
            return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}, algo={}'.format(
            self.weight.shape[0],
            self.weight.shape[1],
            self._dtype,
            name_str,
            self.algo,
        )


class Upsample(Layer):
    """
    This op resizes a batch of images.

    The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
    or 4-D (num_batches, channels, in_h, in_w), or a 5-D Tensor of the shape
    (num_batches, channels, in_d, in_h, in_w) or (num_batches, in_d, in_h, in_w, channels),
    where in_w is the width of the input tensor, in_h is the height of the input tensor,
    and in_d is the depth of the input tensor. The resizing only applies on the
    three dimensions (depth, height and width).

    Supporting resample methods:
        'linear' : Linear interpolation
        'bilinear' : Bilinear interpolation
        'trilinear' : Trilinear interpolation
        'nearest' : Nearest neighbor interpolation
        'bicubic' : Bicubic interpolation

    Linear interpolation is the method of using a line connecting two known quantities
    to determine the value of an unknown quantity between the two known quantities.

    Nearest neighbor interpolation is to perform nearest neighbor interpolation
    in both the 3rd dimension(in height direction) and the 4th dimension(in width
    direction) on input tensor.

    Bilinear interpolation is an extension of linear interpolation for
    interpolating functions of two variables (e.g. H-direction and
    W-direction in this op) on a rectilinear 2D grid. The key idea is
    to perform linear interpolation first in one direction, and then
    again in the other direction.

    Bicubic interpolation is an extension of cubic interpolation for interpolating
    data points on a two-dimensional regular grid. The interpolated surface is
    smoother than corresponding surfaces obtained by bilinear interpolation or
    nearest-neighbor interpolation.

    Trilinear interpolation is an extension of linear interpolation for
    interpolating functions of three variables (e.g. D-direction,
    H-direction and W-direction in this op) on a rectilinear 3D grid.
    The linear interpolation is performed on three directions.
    align_corners and align_mode are optional parameters, and the calculation method
    of interpolation can be selected by them.

    Area interpolation is to perform area interpolation
    in both the 3rd dimension(in height direction) , the 4th dimension(in width
    direction) and the 5th dimension(in depth direction) on input tensor. Set to
    area will directly call `paddle.nn.functional.adaptive_avg_pool1d` or
    `paddle.nn.functional.adaptive_avg_pool2d` or `paddle.nn.functional.adaptive_avg_pool3d`.

    Example:

    .. code-block:: text

        For scale_factor:
            if align_corners = True && out_size > 1 :
              scale_factor = (in_size-1.0)/(out_size-1.0)
            else:
              scale_factor = float(in_size/out_size)

        Linear interpolation:
            if:
                align_corners = False , align_mode = 0
                input : (N,C,W_in)
                output: (N,C,W_out) where:
                W_out = (W_{in}+0.5) * scale_{factor} - 0.5
            else:
                input : (N,C,W_in)
                output: (N,C,W_out) where:
                W_out = W_{in} * scale_{factor}

        Nearest neighbor interpolation:
          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = floor (H_{in} * scale_{factor})
              W_out = floor (W_{in} * scale_{factor})
          else:
              align_corners = True
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

        Bilinear interpolation:
          if:
              align_corners = False , align_mode = 0

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

        Bicubic interpolation:
          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5

          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

        Trilinear interpolation:
          if:
              align_corners = False , align_mode = 0
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

    For details of linear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Linear_interpolation.

    For details of nearest neighbor interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.

    For details of bilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bilinear_interpolation.

    For details of bicubic interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bicubic_interpolation

    For details of trilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Trilinear_interpolation.

    Parameters:
        x (Tensor): 3-D, 4-D or 5-D Tensor, its data type is float32, float64, or uint8,
                          its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
             layer, the shape is (out_w, ) when input is a 3-D Tensor, the shape is (out_h, out_w)
             when input is a 4-D Tensor and is (out_d, out_h, out_w) when input is a 5-D Tensor.
             Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
             If a Tensor, its dimension should be 1.
        scale_factor (float|Tensor|list|tuple|None): The multiplier for the input height or width. At
             least one of :attr:`size` or :attr:`scale_factor` must be set.
             And :attr:`size` has a higher priority than :attr:`scale_factor`. Has to match input size if it is either a list or a tuple or a Tensor.
             Default: None.
        mode (str): The resample method. It supports 'linear', 'nearest', 'bilinear',
                       'bicubic' and 'trilinear' currently. Default: 'nearest'
        align_corners(bool) :  An optional bool. If True, the centers of the 4 corner pixels of the
                               input and output tensors are aligned, preserving the values at the
                               corner pixels.
                               Default: False
        align_mode(int)  :  An optional flag for linear/bilinear/trilinear interpolation. Refer to the formula in the example above,
                            it can be \'0\' for src_idx = scale_factor*(dst_index+0.5)-0.5 , can be \'1\' for
                            src_idx = scale_factor*dst_index.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from:`NCW`, `NWC`, `"NCHW"`, `"NHWC"`, `"NCDHW"`,
            `"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
            `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`, the data is stored
            in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`
    Returns:
        A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels),
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
        or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.rand([2,3,6,10], dtype="float32")
            upsample_out = paddle.nn.Upsample(size=[12,12])

            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
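
            # scale_factor can be used instead of size; here both spatial
            # dimensions are scaled by 2.0 (an illustrative configuration):
            upsample_scale = paddle.nn.Upsample(scale_factor=2.0, mode='bilinear')
            output_scaled = upsample_scale(x=input)
            print(output_scaled.shape)
            # [2, 3, 12, 20]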

    """

    def __init__(
        self,
        size=None,
        scale_factor=None,
        mode='nearest',
        align_corners=False,
        align_mode=0,
        data_format='NCHW',
        name=None,
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode.lower()
        self.align_corners = align_corners
        self.align_mode = align_mode
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            align_mode=self.align_mode,
            data_format=self.data_format,
            name=self.name,
        )

        return out

    def extra_repr(self):
        if self.scale_factor is not None:
            main_str = f'scale_factor={self.scale_factor}'
        else:
            main_str = f'size={self.size}'
        name_str = f', name={self.name}' if self.name else ''
        return '{}, mode={}, align_corners={}, align_mode={}, data_format={}{}'.format(
            main_str,
            self.mode,
            self.align_corners,
            self.align_mode,
            self.data_format,
            name_str,
        )


class UpsamplingNearest2D(Layer):
    """
    This op upsamples a batch of images, using nearest neighbours' pixel values.
    The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w),
    where in_w is width of the input tensor, in_h is the height of the input tensor.
    And the upsampling only applies on the two dimensions(height and width).
    Nearest neighbor interpolation is to perform nearest neighbor interpolation
    in both the 3rd dimension(in height direction) and the 4th dimension(in width
    direction) on input tensor.

    For details of nearest neighbor interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.

    Parameters:
        x (Tensor): 4-D Tensor, its data type is float32, float64, or uint8,
                          its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
             layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
             Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
             If a Tensor, its dimension should be 1.
        scale_factor (float|int|list|tuple|Tensor|None): The multiplier for the input height or width. At
             least one of :attr:`size` or :attr:`scale_factor` must be set.
             And :attr:`size` has a higher priority than :attr:`scale_factor`.
             Has to match input size if it is either a list or a tuple or a Tensor.
             Default: None.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from:`NCW`, `NWC`, `"NCHW"`, `"NHWC"`, `"NCDHW"`,
            `"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
            `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`, the data is stored
            in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`
    Returns:
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),


    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_data = paddle.rand(shape=(2,3,6,10)).astype("float32")
            upsample_out = paddle.nn.UpsamplingNearest2D(size=[12,12])
            input = paddle.to_tensor(input_data)
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
    """

    def __init__(
        self, size=None, scale_factor=None, data_format='NCHW', name=None
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode='nearest',
            align_corners=False,
            align_mode=0,
            data_format=self.data_format,
            name=self.name,
        )

        return out

    def extra_repr(self):
        if self.scale_factor is not None:
            main_str = f'scale_factor={self.scale_factor}'
        else:
            main_str = f'size={self.size}'
        name_str = f', name={self.name}' if self.name else ''
        return '{}, data_format={}{}'.format(
            main_str, self.data_format, name_str
        )


class UpsamplingBilinear2D(Layer):
    """
    This op upsamples a batch of images, using bilinear interpolation.
    The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w),
    where in_w is width of the input tensor, in_h is the height of the input tensor.
    And the upsampling only applies on the two dimensions(height and width).
    Bilinear interpolation is an extension of linear interpolation for
    interpolating functions of two variables (e.g. H-direction and
    W-direction in this op) on a rectilinear 2D grid. The key idea is
    to perform linear interpolation first in one direction, and then
    again in the other direction.

    For details of bilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bilinear_interpolation.

    Parameters:
        x (Tensor): 4-D Tensor, its data type is float32, float64, or uint8,
                          its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
             layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
             Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
             If a Tensor, its dimension should be 1.
        scale_factor (float|int|list|tuple|Tensor|None): The multiplier for the input height or width. At
             least one of :attr:`size` or :attr:`scale_factor` must be set.
             And :attr:`size` has a higher priority than :attr:`scale_factor`.
             Has to match input size if it is either a list or a tuple or a Tensor.
             Default: None.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from:`NCW`, `NWC`, `"NCHW"`, `"NHWC"`, `"NCDHW"`,
            `"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
            `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`, the data is stored
            in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`
    Returns:
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_data = paddle.rand(shape=(2,3,6,10)).astype("float32")
            upsample_out = paddle.nn.UpsamplingBilinear2D(size=[12,12])
            input = paddle.to_tensor(input_data)
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
    """

    def __init__(
        self, size=None, scale_factor=None, data_format='NCHW', name=None
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=True,
            align_mode=0,
            data_format=self.data_format,
            name=self.name,
        )

        return out

    def extra_repr(self):
        if self.scale_factor is not None:
            main_str = f'scale_factor={self.scale_factor}'
        else:
            main_str = f'size={self.size}'
        name_str = f', name={self.name}' if self.name else ''
        return '{}, data_format={}{}'.format(
            main_str, self.data_format, name_str
        )


class Bilinear(Layer):
    r"""

    This layer performs a bilinear transformation on two inputs.

    .. math::

      out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,out\_features-1

      out = out + b

    In this formula:
     - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features].
     - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features].
     - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features].
     - :math:`out_{i}`: the i-th element of out, shape is [batch_size], and out's shape is [batch_size, out_features].
     - :math:`b`: the learned bias, shape is [1, out_features].
     - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`.

    Parameters:
       in1_features (int): The dimension of each first input(`x1`).
       in2_features (int): The dimension of each second input(`x2`).
       out_features (int): The dimension of output of this layer.
       weight_attr (ParamAttr, optional): The parameter attribute for the learnable weights of
       this layer. The default value is None.
       bias_attr (ParamAttr, optional): The parameter attribute for the bias
           of this layer. If it is set to False, no bias will be added to the output units.
           If it is set to None, the bias is initialized to zero. The default value is None.
       name (str, optional): The default value is None. Normally there is no need for user
           to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Returns:
       Tensor: A 2-D Tensor of shape [batch_size, out_features].

    Examples:
       .. code-block:: python

        import paddle

        layer1 = paddle.rand((5, 5)).astype('float32')
        layer2 = paddle.rand((5, 4)).astype('float32')
        bilinear = paddle.nn.Bilinear(
            in1_features=5, in2_features=4, out_features=1000)
        result = bilinear(layer1,layer2)    # result shape [5, 1000]
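
        # The learned weight has shape [out_features, in1_features, in2_features]
        # and the bias has shape [1, out_features], following the formula above
        # (shapes shown for this example's sizes):
        print(bilinear.weight.shape)   # [1000, 5, 4]
        print(bilinear.bias.shape)     # [1, 1000]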

    """

    def __init__(
        self,
        in1_features,
        in2_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
    ):
        super().__init__()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._name = name
        self._in1_features = in1_features
        self._in2_features = in2_features
        self._out_features = out_features
        self._dtype = self._helper.get_default_dtype()

        weight_shape = [
            self._out_features,
            self._in1_features,
            self._in2_features,
        ]
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=weight_shape,
            dtype=self._dtype,
            is_bias=False,
        )
        bias_shape = [1, self._out_features]
        self.bias = self.create_parameter(
            attr=self._bias_attr,
            shape=bias_shape,
            dtype=self._dtype,
            is_bias=True,
        )

    def forward(self, x1, x2):
        return F.bilinear(x1, x2, self.weight, self.bias, self._name)

    def extra_repr(self):
        name_str = f', name={self._name}' if self._name else ''
        return 'in1_features={}, in2_features={}, out_features={}, dtype={}{}'.format(
            self._in1_features,
            self._in2_features,
            self._out_features,
            self._dtype,
            name_str,
        )


class Dropout(Layer):
    r"""
    Dropout is a regularization technique for reducing overfitting by preventing
    neuron co-adaptation during training as described in the paper:
    `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/abs/1207.0580>`_

    The dropout operator randomly sets the outputs of some units to zero, while upscaling the others
    according to the given dropout probability.

    See :ref:`api_paddle_nn_functional_dropout` for more details.

    In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float|int, optional): Probability of setting units to zero. Default: 0.5
        axis (int|list|tuple, optional): The axis along which the dropout is performed. Default: None.
        mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']

                               1. upscale_in_train (default), upscale the output at training time

                                  - train: :math:`out = input \times \frac{mask}{(1.0 - p)}`
                                  - inference: :math:`out = input`

                               2. downscale_in_infer, downscale the output at inference

                                  - train: :math:`out = input \times mask`
                                  - inference: :math:`out = input \times (1.0 - p)`
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor, the same shape as input.


    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32")
            m = paddle.nn.Dropout(p=0.5)

            y_train = m(x)
            print(y_train)
            # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[2., 0., 6.],
            #         [0., 0., 0.]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[1., 2., 3.],
            #         [4., 5., 6.]])
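
            # With mode="downscale_in_infer" the outputs are left unscaled during
            # training and scaled by (1 - p) at inference time instead
            # (training outputs are random; shown only as an illustration):
            m2 = paddle.nn.Dropout(p=0.5, mode="downscale_in_infer")
            m2.eval()
            y_infer = m2(x)   # equals x * 0.5 in evaluation mode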
    """

    def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
        super().__init__()

        self.p = p
        self.axis = axis
        self.mode = mode
        self.name = name

    def forward(self, input):
        out = F.dropout(
            input,
            p=self.p,
            axis=self.axis,
            training=self.training,
            mode=self.mode,
            name=self.name,
        )
        return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'p={}, axis={}, mode={}{}'.format(
            self.p, self.axis, self.mode, name_str
        )


class Dropout2D(Layer):
    """
    Randomly zero out entire channels (in the batched input 4d tensor with the shape `NCHW` ,
    a channel is a 2D feature map with the shape `HW`). Each channel will be zeroed out independently
    on every forward call with probability `p` using samples from a Bernoulli distribution.
    Dropout2D will help promote independence between feature maps as described in the paper:
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_

    See :ref:`api_paddle_nn_functional_dropout2d` for more details.

    In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float, optional): Probability of setting units to zero. Default: 0.5.
        data_format (str, optional): Specify the data format of the input, and the data format of the output will be consistent with that of the input. An optional string from `NCHW` or `NHWC`. When it is `NCHW`, the data is stored in the order of: [batch_size, input_channels, input_height, input_width]. Default: `NCHW`.
        name (str, optional): Name for the operation, Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 4-D tensor.
        - output: 4-D tensor, the same shape as input.


    Examples:
        .. code-block:: python

            import paddle

            x = paddle.rand([2, 2, 1, 3], dtype="float32")
            print(x)
            # Tensor(shape=[2, 2, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[0.10052059, 0.93890846, 0.45351565]],
            #          [[0.47507706, 0.45021373, 0.11331241]]],

            #         [[[0.53358698, 0.97375143, 0.34997326]],
            #          [[0.24758087, 0.52628899, 0.17970420]]]])

            m = paddle.nn.Dropout2D(p=0.5)
            y_train = m(x)
            print(y_train)
            # Tensor(shape=[2, 2, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[0.        , 0.        , 0.        ]],
            #          [[0.95015413, 0.90042746, 0.22662482]]],

            #         [[[1.06717396, 1.94750285, 0.69994652]],
            #          [[0.        , 0.        , 0.        ]]]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[2, 2, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[0.10052059, 0.93890846, 0.45351565]],
            #          [[0.47507706, 0.45021373, 0.11331241]]],

            #         [[[0.53358698, 0.97375143, 0.34997326]],
            #          [[0.24758087, 0.52628899, 0.17970420]]]])
    """

    def __init__(self, p=0.5, data_format='NCHW', name=None):
        super().__init__()

        self.p = p
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        out = F.dropout2d(
            input,
            p=self.p,
            training=self.training,
            data_format=self.data_format,
            name=self.name,
        )
        return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'p={}, data_format={}{}'.format(
            self.p, self.data_format, name_str
        )


class Dropout3D(Layer):
    """
    Randomly zero out entire channels (in the batched input 5d tensor with the shape `NCDHW` ,
    a channel is a 3D feature map with the shape `DHW` ). Each channel will be zeroed out independently
    on every forward call with probability `p` using samples from a Bernoulli distribution.
    Dropout3D will help promote independence between feature maps as described in the paper:
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_

    See :ref:`api_paddle_nn_functional_dropout3d` for more details.

    In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float | int, optional): Probability of setting units to zero. Default: 0.5.
        data_format (str, optional): Specify the data format of the input, and the data format of the output will be consistent with that of the input. An optional string from `NCDHW` or `NDHWC`. When it is `NCDHW`, the data is stored in the order of: [batch_size, input_channels, input_depth, input_height, input_width]. Default: `NCDHW`.
        name (str, optional): Name for the operation, Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 5-D tensor.
        - output: 5-D tensor, the same shape as input.


    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(24, dtype="float32").reshape((1, 2, 2, 2, 3))
            print(x)
            # Tensor(shape=[1, 2, 2, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[[0. , 1. , 2. ],
            #            [3. , 4. , 5. ]],
            #           [[6. , 7. , 8. ],
            #            [9. , 10., 11.]]],

            #          [[[12., 13., 14.],
            #            [15., 16., 17.]],
            #           [[18., 19., 20.],
            #            [21., 22., 23.]]]]])

            m = paddle.nn.Dropout3D(p=0.5)
            y_train = m(x)
            print(y_train)
            # Tensor(shape=[1, 2, 2, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[[0. , 2. , 4. ],
            #            [6. , 8. , 10.]],
            #           [[12., 14., 16.],
            #            [18., 20., 22.]]],

            #          [[[0. , 0. , 0. ],
            #            [0. , 0. , 0. ]],
            #           [[0. , 0. , 0. ],
            #            [0. , 0. , 0. ]]]]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[1, 2, 2, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[[0. , 1. , 2. ],
            #            [3. , 4. , 5. ]],
            #           [[6. , 7. , 8. ],
            #            [9. , 10., 11.]]],

            #          [[[12., 13., 14.],
            #            [15., 16., 17.]],
            #           [[18., 19., 20.],
            #            [21., 22., 23.]]]]])
    """

    def __init__(self, p=0.5, data_format='NCDHW', name=None):
        super().__init__()

        self.p = p
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        out = F.dropout3d(
            input,
            p=self.p,
            training=self.training,
            data_format=self.data_format,
            name=self.name,
        )
        return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'p={}, data_format={}{}'.format(
            self.p, self.data_format, name_str
        )


class AlphaDropout(Layer):
    """
    Alpha Dropout is a type of Dropout that maintains the self-normalizing property. For an input with
    zero mean and unit standard deviation, the output of Alpha Dropout maintains the original mean and
    standard deviation of the input. Alpha Dropout fits well with the SELU activation function by randomly setting
    activations to the negative saturation value.

    For more information, please refer to:
    `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_

    In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float | int): Probability of setting units to zero. Default: 0.5
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[-1, 1], [-1, 1]], dtype="float32")
            m = paddle.nn.AlphaDropout(p=0.5)
            y_train = m(x)
            print(y_train)
            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[-0.77919382,  1.66559887],
            #         [-0.77919382, -0.77919382]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[-1.,  1.],
            #         [-1.,  1.]])
    """

    def __init__(self, p=0.5, name=None):
        super().__init__()
        self.p = p
        self.name = name

    def forward(self, input):
        out = F.alpha_dropout(
            input, p=self.p, training=self.training, name=self.name
        )
        return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return f'p={self.p}{name_str}'


class Pad1D(Layer):
    """
    This interface is used to construct a callable object of the ``Pad1D`` class.
    Pad tensor according to ``pad``, ``mode`` and ``value``.
    If mode is ``reflect``, pad[0] and pad[1] must be no greater than width-1.

    Parameters:
        padding (Tensor|list[int]|int): The padding size with data type ``'int'``. If it is ``'int'``, the
            same padding is used in both dimensions. Otherwise, [len(padding)/2] dimensions
            of input will be padded. The pad has the form (pad_left, pad_right).
        mode (str, optional): Four modes: ``'constant'`` (default), ``'reflect'``, ``'replicate'``, ``'circular'``. Default: ``'constant'``.

           - 'constant' mode, uses a constant value to pad the input tensor.
           - 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
           - 'replicate' mode, uses input boundaries to pad the input tensor.
           - 'circular' mode, uses circular input to pad the input tensor.

        value (float, optional): The value to fill the padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string from: ``'NCL'``, ``'NLC'``. Specify the data format of the input data.
           Default: ``'NCL'``.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: ``'None'``.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 2, 3)
            pad = [1, 2]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad1D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[0. 1. 2. 3. 0. 0.]
            #   [0. 4. 5. 6. 0. 0.]]]
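
            # With mode="reflect" the border values are mirrored instead of
            # padded with a constant (an illustrative variation on the same data):
            my_pad_reflect = nn.Pad1D(padding=[1, 2], mode="reflect")
            result_reflect = my_pad_reflect(data)
            print(result_reflect)
            # [[[2. 1. 2. 3. 2. 1.]
            #   [5. 4. 5. 6. 5. 4.]]]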
    """

    def __init__(
        self, padding, mode='constant', value=0.0, data_format="NCL", name=None
    ):
        super().__init__()
        self._pad = _npairs(padding, 1)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = f', name={self._name}' if self._name else ''
        return 'padding={}, mode={}, value={}, data_format={}{}'.format(
            self._pad, self._mode, self._value, self._data_format, name_str
        )


class Pad2D(Layer):
L
littletomatodonkey 已提交
1243
    """
L
littletomatodonkey 已提交
1244
    This interface is used to construct a callable object of the ``Pad2D`` class.
1245 1246
    Pad tensor according to ``pad``, ``mode`` and ``value``.
    If mode is ``'reflect'``, pad[0] and pad[1] must be no greater
L
littletomatodonkey 已提交
1247
    than width-1. The height dimension has the same condition.
L
littletomatodonkey 已提交
1248 1249

    Parameters:
1250
        padding (Tensor|list[int]|int): The padding size with data type ``'int'``. If is ``'int'``, use the
1251 1252
            same padding in all dimensions. Else [len(padding)/2] dimensions of input will be padded.
            The pad has the form (pad_left, pad_right, pad_top, pad_bottom).
1253
        mode (str, optional): Four modes: ``'constant'`` (default), ``'reflect'``, ``'replicate'``, ``'circular'``. Default: ``'constant'``.
1254 1255 1256 1257 1258 1259

           - 'constant' mode, uses a constant value to pad the input tensor.
           - 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
           - 'replicate' mode, uses input boundaries to pad the input tensor.
           - 'circular' mode, uses circular input to pad the input tensor.

        value (float, optional): The value to fill the padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string from: ``'NCHW'``, ``'NHWC'``. Specify the data format of the input data.
           Default: ``'NCHW'``.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: ``'None'``.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 2, 3)
            pad = [1, 0, 1, 2]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad2D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[[0. 0. 0. 0.]
            #    [0. 1. 2. 3.]
            #    [0. 4. 5. 6.]
            #    [0. 0. 0. 0.]
            #    [0. 0. 0. 0.]]]]
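            # A hedged illustration of the integer form of ``padding``: an int pads
            # every side by the same amount, so padding=1 acts like [1, 1, 1, 1].
            uniform_pad = nn.Pad2D(padding=1, mode=mode)
            print(uniform_pad(data))
            # expected:
            # [[[[0. 0. 0. 0. 0.]
            #    [0. 1. 2. 3. 0.]
            #    [0. 4. 5. 6. 0.]
            #    [0. 0. 0. 0. 0.]]]]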
    """

    def __init__(
        self, padding, mode='constant', value=0.0, data_format="NCHW", name=None
    ):
        super().__init__()
        self._pad = _npairs(padding, 2)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = f', name={self._name}' if self._name else ''
        return 'padding={}, mode={}, value={}, data_format={}{}'.format(
            self._pad, self._mode, self._value, self._data_format, name_str
        )


class ZeroPad2D(Layer):
    """
    This interface is used to construct a callable object of the ``ZeroPad2D`` class.
    Pads the input tensor boundaries with zero.

    Parameters:
        padding (Tensor | List[int] | int): The padding size with data type int. If it is an int, use the
            same padding in all dimensions. Else [len(padding)/2] dimensions of input will be padded.
            The pad has the form (pad_left, pad_right, pad_top, pad_bottom).
        data_format (str): A string from: "NCHW", "NHWC". Specify the data format of the input data.
           Default is "NCHW".
        name (str, optional): The default value is None. Normally there is no need for
            user to set this property.  For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x(Tensor): The input tensor of zeropad2d operator, which is a 4-D tensor.
          The data type can be float32, float64.
        - output(Tensor): The output tensor of zeropad2d operator, which is a 4-D tensor.
          The data type is same as input x.

    Examples:
        Examples are as follows.

        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = paddle.to_tensor([1, 1, 2, 3])
            pad = [1, 0, 1, 2]
            data = paddle.arange(paddle.prod(input_shape), dtype="float32").reshape(input_shape) + 1

            my_pad = nn.ZeroPad2D(padding=pad)
            result = my_pad(data)

            print(result)
            # [[[[0. 0. 0. 0.]
            #    [0. 1. 2. 3.]
            #    [0. 4. 5. 6.]
            #    [0. 0. 0. 0.]
            #    [0. 0. 0. 0.]]]]
    """

    def __init__(self, padding, data_format="NCHW", name=None):
        super().__init__()
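        # Zero padding is constant-mode padding with a fill value fixed at 0.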
        self._pad = _npairs(padding, 2)
        self._mode = 'constant'
        self._value = 0.0
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = f', name={self._name}' if self._name else ''
        return 'padding={}, data_format={}{}'.format(
            self._pad, self._data_format, name_str
        )


class Pad3D(Layer):
    """
    This interface is used to construct a callable object of the ``Pad3D`` class.
    Pad tensor according to ``'pad'``, ``'mode'`` and ``'value'``.
    If mode is ``'reflect'``, pad[0] and pad[1] must be no greater
    than width-1. The height and depth dimensions have the same condition.

    Parameters:
        padding (Tensor|list[int]|int): The padding size with data type ``'int'``. If it is an ``'int'``, use the
            same padding in all dimensions. Else [len(padding)/2] dimensions
            of input will be padded. The pad has the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
        mode (str, optional): Four modes: ``'constant'`` (default), ``'reflect'``, ``'replicate'``, ``'circular'``. Default: ``'constant'``.

           - 'constant' mode, uses a constant value to pad the input tensor.
           - 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
           - 'replicate' mode, uses input boundaries to pad the input tensor.
           - 'circular' mode, uses circular input to pad the input tensor.

        value (float, optional): The value to fill the padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string from: ``'NCDHW'``, ``'NDHWC'``. Specify the data format of the input data.
           Default: ``'NCDHW'``.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: ``'None'``.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 1, 2, 3)
            pad = [1, 0, 1, 2, 0, 0]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad3D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[[[0. 0. 0. 0.]
            #     [0. 1. 2. 3.]
            #     [0. 4. 5. 6.]
            #     [0. 0. 0. 0.]
            #     [0. 0. 0. 0.]]]]]
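            # Note: the last pair of ``pad``, (pad_front, pad_back) = (0, 0),
            # leaves the depth dimension at its original size of 1 here.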
    """

    def __init__(
        self,
        padding,
        mode='constant',
        value=0.0,
        data_format="NCDHW",
        name=None,
    ):
        super().__init__()
        self._pad = _npairs(padding, 3)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = f', name={self._name}' if self._name else ''
        return 'padding={}, mode={}, value={}, data_format={}{}'.format(
            self._pad, self._mode, self._value, self._data_format, name_str
        )


class CosineSimilarity(Layer):
    """
    This interface is used to compute cosine similarity between x1 and x2 along axis.

    Parameters:
        axis (int): Dimension of vectors to compute cosine similarity. Default is 1.
        eps(float): Small value to avoid division by zero. Default is 1e-8.
    Returns:
        None

    Examples:
        .. code-block:: text

            Case 0:
                x1 = [[0.8024077  0.9927354  0.27238318 0.8344984 ]
                     [0.48949873 0.5797396  0.65444374 0.66510963]
                     [0.1031398  0.9614342  0.08365563 0.6796464 ]
                     [0.10760343 0.7461209  0.7726148  0.5801006 ]]
                x2 = [[0.62913156 0.1536727  0.9847992  0.04591406]
                     [0.9098952  0.15715368 0.8671125  0.3156102 ]
                     [0.4427798  0.54136837 0.5276275  0.32394758]
                     [0.3769419  0.8535014  0.48041078 0.9256797 ]]
                axis = 1
                eps = 1e-8
                Out: [0.5275037  0.8368967  0.75037485 0.9245899]

    Code Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x1 = paddle.to_tensor([[1., 2., 3.],
                                [2., 3., 4.]], dtype="float32")
            x2 = paddle.to_tensor([[8., 3., 3.],
                                [2., 3., 4.]], dtype="float32")

            cos_sim_func = nn.CosineSimilarity(axis=0)
            result = cos_sim_func(x1, x2)
            print(result)
            # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [0.65079135, 0.98058069, 1.        ])
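            # A rough sketch of the underlying definition (assuming standard
            # tensor methods; the eps stabilizer is ignored here): the result is
            # the dot product along axis 0 divided by the product of the norms.
            manual = (x1 * x2).sum(axis=0) / (
                x1.square().sum(axis=0).sqrt() * x2.square().sum(axis=0).sqrt()
            )
            print(manual)
            # expected to be close to `result` above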
    """

    def __init__(self, axis=1, eps=1e-8):
        super().__init__()
        self._axis = axis
        self._eps = eps

    def forward(self, x1, x2):
        return F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps)

    def extra_repr(self):
        return 'axis={_axis}, eps={_eps}'.format(**self.__dict__)


class Embedding(Layer):
    r"""

    Embedding Layer, used to construct a callable object of the ``Embedding`` class.
    For specific usage, refer to code examples. It implements the function of the Embedding Layer.
    This layer is used to look up the embedding vectors of ids provided by :attr:`x` .
    It automatically constructs a 2D embedding matrix based on the
    input :attr:`num_embeddings` and :attr:`embedding_dim`.

    The shape of output Tensor is generated by appending an emb_size dimension to the
    last dimension of the input Tensor shape.

    Note:
        The id in :attr:`x` must satisfy :math:`0 <= id < num_embeddings` ,
        otherwise the program will throw an exception and exit.

    .. code-block:: text

        Case 1:

        x is a Tensor. padding_idx = -1
            x.data = [[1, 3], [2, 4], [4, 127]]
            x.shape = [3, 2]
        Given size = [128, 16]
        output is a Tensor:
            out.shape = [3, 2, 16]
            out.data = [[[0.129435295, 0.244512452, ..., 0.436322452],
                        [0.345421456, 0.524563927, ..., 0.144534654]],

                        [[0.345249859, 0.124939536, ..., 0.194353745],
                        [0.945345345, 0.435394634, ..., 0.435345365]],

                        [[0.945345345, 0.435394634, ..., 0.435345365],
                        [0.0,         0.0,         ..., 0.0        ]]]  # padding data
        Since the input padding_idx is less than 0, it is automatically converted to padding_idx = -1 + 128 = 127.
        It will pad all-zero data when the id is 127.

    Parameters:
        num_embeddings (int): Just one element which indicates the size
            of the dictionary of embeddings.
        embedding_dim (int): Just one element which indicates the size of each embedding vector.
        padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
            to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup
            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
            If set None, it makes no effect to output. Default: None.
        sparse(bool, optional): The flag indicating whether to use sparse update. This parameter only
            affects the performance of the backwards gradient update. It is recommended to set
            True because sparse update is faster. But some optimizers do not support sparse update,
            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
            In these cases, sparse must be False. Default: False.
        weight_attr(ParamAttr, optional): To specify the weight parameter property. Default: None, which means the
            default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
            user-defined or pre-trained word vectors can be loaded with the :attr:`param_attr` parameter.
            The local word vector needs to be transformed into numpy format, and the shape of local word
            vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer`
            is used to load custom or pre-trained word vectors. See code example for details.
        name(str|None, optional): For detailed information, please refer
               to :ref:`api_guide_Name`. Usually name is no need to set and
               None by default.

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False)
            embedding = paddle.nn.Embedding(4, 3, sparse=True)

            w0 = paddle.to_tensor([[0., 0., 0.],
                                [1., 1., 1.],
                                [2., 2., 2.],
                                [3., 3., 3.]], dtype="float32")
            embedding.weight.set_value(w0)
            print(embedding.weight)
            # Tensor(shape=[4, 3], dtype=float32, place=Place(gpu:0), stop_gradient=False,
            #        [[0., 0., 0.],
            #         [1., 1., 1.],
            #         [2., 2., 2.],
            #         [3., 3., 3.]])

            adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
            adam.clear_grad()


            out = embedding(x)
            print(out)
            # Tensor(shape=[3, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=False,
            #        [[[0., 0., 0.]],
            #         [[1., 1., 1.]],
            #         [[3., 3., 3.]]])

            out.backward()
            adam.step()

    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        sparse=False,
        weight_attr=None,
        name=None,
    ):
        super().__init__()
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._sparse = sparse
        self._is_distributed = False
        self._padding_idx = padding_idx

        if self._num_embeddings <= 0:
            raise ValueError("num_embeddings must be greater than 0")

        if self._embedding_dim <= 0:
            raise ValueError("embedding_dim must be greater than 0")

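        # Normalize padding_idx: None disables it (stored as -1 internally),
        # and a negative index is mapped to num_embeddings + padding_idx.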
        padding_idx = (
            -1
            if padding_idx is None
            else padding_idx
            if padding_idx >= 0
            else (num_embeddings + padding_idx)
        )

        if padding_idx >= num_embeddings or padding_idx < -num_embeddings:
            raise ValueError(
                "padding_idx must be within [-{}, {})".format(
                    num_embeddings, num_embeddings
                )
            )

        self._dtype = self._helper.get_default_dtype()
        self._size = [self._num_embeddings, self._embedding_dim]

        self._weight_attr = weight_attr
        self._remote_prefetch = False
        self._name = name
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=self._size,
            dtype=self._dtype,
            is_bias=False,
        )

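        # In dynamic graph mode, zero out the row at padding_idx so that
        # looking up the padding id always returns an all-zero vector.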
        if in_dynamic_mode() and padding_idx != -1:
            with paddle.no_grad():
                self.weight[padding_idx] = 0.0

    def forward(self, x):
        return F.embedding(
            x,
            weight=self.weight,
            padding_idx=self._padding_idx,
            sparse=self._sparse,
            name=self._name,
        )

    def extra_repr(self):
        main_str = '{_num_embeddings}, {_embedding_dim}'
        if self._padding_idx is not None:
            main_str += ', padding_idx={_padding_idx}'
        main_str += ', sparse={_sparse}'
        if self._name is not None:
            main_str += ', name={_name}'
        return main_str.format(**self.__dict__)


class Unfold(Layer):
    """
    Returns a col buffer of sliding local blocks of input x, also known
    as im2col for batched 2D image tensors. For each block under the convolution filter,
    all elements will be rearranged as a column. As the convolution filter slides over
    the input feature map, a series of such columns will be formed.

    For each input :math:`x` with shape [N, C, H, W], the output shape [N, Cout, Lout]
    can be calculated as follows.

    See ``paddle.nn.functional.unfold`` for more details.


    Parameters:
        kernel_sizes(int|list):   The size of convolution kernel, should be [k_h, k_w]
                                  or an integer k treated as [k, k].
        strides(int|list):        The strides, should be [stride_h, stride_w]
                                  or an integer stride treated as [stride, stride].
                                  For default, strides will be [1, 1].
        paddings(int|list):       The paddings of each dimension, should be
                                  [padding_top, padding_left, padding_bottom, padding_right]
                                  or [padding_h, padding_w] or an integer padding.
                                  If [padding_h, padding_w] was given, it will be expanded to
                                  [padding_h, padding_w, padding_h, padding_w]. If an integer
                                  padding was given, [padding, padding, padding, padding] will
                                  be used. For default, paddings will be [0, 0, 0, 0]
        dilations(int|list):      the dilations of convolution kernel, should be
                                  [dilation_h, dilation_w], or an integer dilation treated as
                                  [dilation, dilation]. For default, it will be [1, 1].
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`


    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x = paddle.randn((100,3,224,224))
            unfold = nn.Unfold(kernel_sizes=[3, 3])
            result = unfold(x)
            print(result)
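            # A hedged note on the expected shape: with a 3x3 kernel, stride 1
            # and no padding, result.shape should be
            # [100, 3 * 3 * 3, (224 - 3 + 1) * (224 - 3 + 1)] = [100, 27, 49284]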
    """

    def __init__(
        self, kernel_sizes, dilations=1, paddings=0, strides=1, name=None
    ):
        super().__init__()

        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.paddings = paddings
        self.strides = strides
        self.name = name

    def forward(self, input):
        return F.unfold(
            input,
            kernel_sizes=self.kernel_sizes,
            strides=self.strides,
            paddings=self.paddings,
            dilations=self.dilations,
            name=self.name,
        )

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'kernel_size={}, dilation={}, padding={}, stride={}{}'.format(
            self.kernel_sizes,
            self.dilations,
            self.paddings,
            self.strides,
            name_str,
        )


class Fold(Layer):
    r"""

    Combines an array of sliding local blocks into a large containing
    tensor, also known as col2im when operated on a batched 2D image tensor. Fold calculates each
    combined value in the resulting large tensor by summing all values from all containing blocks.


    For each input :math:`x` with shape [N, C_in , L], the output shape [N, C_out, H_out, W_out]
    can be calculated as follows.

    .. math::

        H_{out} &= output\_size[0] \\
        W_{out} &= output\_size[1] \\
        C_{out} &= \frac{C_{in}}{kernel\_sizes[0]\times kernel\_sizes[1]} \\

    Parameters:
        output_sizes(list):       The size of the output, should be [output_size_h, output_size_w]
                                  or an integer o treated as [o, o].
        kernel_sizes(int|list|tuple):   The size of convolution kernel, should be [k_h, k_w]
                                  or an integer k treated as [k, k].
        strides(int|list|tuple, optional):        The strides, should be [stride_h, stride_w]
                                  or an integer stride treated as [stride, stride].
                                  For default, strides will be [1, 1].
        paddings(int|list|tuple, optional):       The paddings of each dimension, should be
                                  [padding_top, padding_left, padding_bottom, padding_right]
                                  or [padding_h, padding_w] or an integer padding.
                                  If [padding_h, padding_w] was given, it will be expanded to
                                  [padding_h, padding_w, padding_h, padding_w]. If an integer
                                  padding was given, [padding, padding, padding, padding] will
                                  be used. For default, paddings will be [0, 0, 0, 0]
        dilations(int|list|tuple, optional):      the dilations of convolution kernel, should be
                                  [dilation_h, dilation_w], or an integer dilation treated as
                                  [dilation, dilation]. For default, it will be [1, 1].
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`


    Returns:
        The tensor formed by combining a group of sliding local blocks.
        The output shape is [N, Cout, H, W] as described above.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x = paddle.randn([2,3*2*2,12])
            fold = nn.Fold(output_sizes=[4, 5], kernel_sizes=2)
            y = fold(x)
            # y.shape = [2,3,4,5]
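            # A hedged note: the 3*2*2 input channels fold back into 3 channels,
            # and the 12 columns correspond to the (4 - 2 + 1) * (5 - 2 + 1) = 12
            # positions of a 2x2 kernel sliding over the 4x5 output.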
   """

    def __init__(
        self,
        output_sizes,
        kernel_sizes,
        dilations=1,
        paddings=0,
        strides=1,
        name=None,
    ):
        super().__init__()

        self.output_sizes = output_sizes
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.paddings = paddings
        self.strides = strides
        self.name = name

    def forward(self, input):
        return F.fold(
            input,
            output_sizes=self.output_sizes,
            kernel_sizes=self.kernel_sizes,
            strides=self.strides,
            paddings=self.paddings,
            dilations=self.dilations,
            name=self.name,
        )

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return 'kernel_size={}, dilation={}, padding={}, stride={}{}'.format(
            self.kernel_sizes,
            self.dilations,
            self.paddings,
            self.strides,
            name_str,
        )


class Flatten(Layer):
    """
    This interface is used to construct a callable object of the ``Flatten`` class.
    For more details, refer to code examples.
    It flattens a contiguous range of dims into a tensor.

    Parameters:
        start_axis(int): first dim to flatten (default = 1)
        stop_axis(int): last dim to flatten (default = -1).

    Returns:
        None

    Examples:

        .. code-block:: python

          import paddle

          inp = paddle.ones([5, 2, 3, 4]).astype('float32')
          flatten = paddle.nn.Flatten(start_axis=1, stop_axis=2)
          y = flatten(inp)
          # y.shape = [5, 6, 4]
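          # A hedged note: with the defaults (start_axis=1, stop_axis=-1), all
          # non-batch dims collapse into one, so the expected result is
          # paddle.nn.Flatten()(inp).shape == [5, 24]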

    """

    def __init__(self, start_axis=1, stop_axis=-1):
        super().__init__()
        self.start_axis = start_axis
        self.stop_axis = stop_axis

    def forward(self, input):
        out = paddle.flatten(
            input, start_axis=self.start_axis, stop_axis=self.stop_axis
        )
        return out


class Unflatten(Layer):
    """
    This interface is used to construct a callable object of the ``Unflatten`` class.
    For more details, refer to code examples.
    It unflattens a certain dimension of the input x Tensor into a desired shape.

    Parameters:
        axis (int): :attr:`axis` to be unflattened, specified as an index into `x.shape`.
        shape (list|tuple|Tensor): Unflatten :attr:`shape` on the specified :attr:`axis`. At most one dimension of the target :attr:`shape` can be -1.
            If the input :attr:`shape` does not contain -1 , the product of all elements in ``shape`` should be equal to ``x.shape[axis]``.
            The data type is `int` . If :attr:`shape` is a list or tuple, the elements of it should be integers or Tensors with shape [].
            If :attr:`shape` is a Tensor, it should be a 1-D Tensor.
        name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.randn(shape=[4, 6, 8])
            shape = [2, 3]
            axis = 1
            unflatten = paddle.nn.Unflatten(axis, shape)
            res = unflatten(x)
            print(res.shape)
            # [4, 2, 3, 8]
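            # Note: at most one entry of `shape` may be -1 and is inferred from
            # x.shape[axis]; e.g. shape=[2, -1] is expected to give the same result here.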

    """

    def __init__(self, axis, shape, name=None):
        super().__init__()
        self.axis = axis
        self.shape = shape
        self.name = name

    def forward(self, input):
        out = paddle.unflatten(
            input, axis=self.axis, shape=self.shape, name=self.name
        )
        return out

    def extra_repr(self):
        name_str = f', name={self.name}' if self.name else ''
        return f'axis={self.axis}, shape={self.shape}{name_str}'