# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the common classes to build a neural network
import paddle
# Flatten is not referenced in this module; the noqa keeps linters from
# flagging the unused import.
from ...fluid.dygraph import Flatten  # noqa: F401
from .. import functional as F
from paddle.nn import Layer
from paddle import in_dynamic_mode

# Empty on purpose: ``from <this module> import *`` exports nothing.
__all__ = []


def _npairs(x, n):
    """Broadcast a scalar into a flat list of ``n * 2`` identical entries.

    Args:
        x: Either a scalar value, or a ``paddle.Tensor`` / ``list`` /
            ``tuple``. Container inputs are passed through untouched; their
            length is NOT checked against ``n``.
        n (int): Number of pairs to generate when ``x`` is a scalar.

    Returns:
        ``x`` itself when it is a Tensor, list or tuple; otherwise a new list
        of length ``n * 2`` in which every element is ``x``.
    """
    passthrough_types = (paddle.Tensor, list, tuple)
    if not isinstance(x, passthrough_types):
        return [x] * (n * 2)
    return x


S
shiyutang 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
class Identity(Layer):
    r"""

    A pass-through layer that returns its input unchanged. Every positional or
    keyword argument given to the constructor is accepted and discarded, so the
    layer can replace another layer without changing the construction call.
    For an input :math:`X` the output :math:`Out` is:

    .. math::

        Out = X

    Parameters:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, n1, n2, ...]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, n1, n2, ...]` .

    Examples:
        .. code-block:: python

          import paddle

          input_tensor = paddle.randn(shape=[3, 2])
          layer = paddle.nn.Identity()
          out = layer(input_tensor)
          # input_tensor: [[-0.32342386 -1.200079  ]
          #                [ 0.7979031  -0.90978354]
          #                [ 0.40597573  1.8095392 ]]
          # out: [[-0.32342386 -1.200079  ]
          #      [ 0.7979031  -0.90978354]
          #      [ 0.40597573  1.8095392 ]]


    """

    def __init__(self, *args, **kwargs):
        # The extra constructor arguments are deliberately ignored.
        super().__init__()

    def forward(self, input):
        # The very same object is handed back; no copy is made.
        return input


Z
zhiboniu 已提交
75
class Linear(Layer):
    r"""

    Fully-connected linear transformation layer. For each input :math:`X` ,
    the equation is:

    .. math::

        Out = XW + b

    where :math:`W` is the weight and :math:`b` is the bias.

    Linear layer takes only one multi-dimensional tensor as input with the
    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
    number of additional dimensions. It multiplies input tensor with the weight
    (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
    an output tensor of shape :math:`[batch\_size, *, out\_features]` .
    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
    shape :math:`[out\_features]` ) will be created and added to the output.

    Parameters:
        in_features (int): The number of input units.
        out_features (int): The number of output units.
        weight_attr (ParamAttr, optional): The attribute for the learnable
            weight of this layer. The default value is None. If no initializer
            is set through this attribute, the weight is initialized with the
            framework's default weight initializer (Xavier), not with zeros.
            For detailed information, please refer to paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.
        name (str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .

    Attribute:
        **weight** (Parameter): the learnable weight of this layer, with shape
        :math:`[in\_features, out\_features]` .

        **bias** (Parameter): the learnable bias of this layer, with shape
        :math:`[out\_features]` .

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` .

    Examples:
        .. code-block:: python

          import paddle

          # Define the linear layer.
          weight_attr = paddle.ParamAttr(
              name="weight",
              initializer=paddle.nn.initializer.Constant(value=0.5))
          bias_attr = paddle.ParamAttr(
              name="bias",
              initializer=paddle.nn.initializer.Constant(value=1.0))
          linear = paddle.nn.Linear(2, 4, weight_attr=weight_attr, bias_attr=bias_attr)
          # linear.weight: [[0.5 0.5 0.5 0.5]
          #                 [0.5 0.5 0.5 0.5]]
          # linear.bias: [1. 1. 1. 1.]

          x = paddle.randn((3, 2), dtype="float32")
          # x: [[-0.32342386 -1.200079  ]
          #     [ 0.7979031  -0.90978354]
          #     [ 0.40597573  1.8095392 ]]
          y = linear(x)
          # y: [[0.23824859 0.23824859 0.23824859 0.23824859]
          #     [0.9440598  0.9440598  0.9440598  0.9440598 ]
          #     [2.1077576  2.1077576  2.1077576  2.1077576 ]]
    """

    def __init__(self,
                 in_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super(Linear, self).__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        # Stored as [in_features, out_features] so forward computes X @ W
        # without a transpose.
        self.weight = self.create_parameter(shape=[in_features, out_features],
                                            attr=self._weight_attr,
                                            dtype=self._dtype,
                                            is_bias=False)
        self.bias = self.create_parameter(shape=[out_features],
                                          attr=self._bias_attr,
                                          dtype=self._dtype,
                                          is_bias=True)
        self.name = name

    def forward(self, input):
        out = F.linear(x=input,
                       weight=self.weight,
                       bias=self.bias,
                       name=self.name)
        return out

    def extra_repr(self):
        # in/out features are recovered from the weight shape rather than
        # being stored separately.
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str)

179

Z
zhiboniu 已提交
180
class Upsample(Layer):
    """
    This op resizes a batch of images.

    The input must be a 3-D Tensor of the shape (num_batches, channels, in_w)
    or (num_batches, in_w, channels), a 4-D Tensor of the shape
    (num_batches, channels, in_h, in_w) or (num_batches, in_h, in_w, channels),
    or a 5-D Tensor of the shape (num_batches, channels, in_d, in_h, in_w) or
    (num_batches, in_d, in_h, in_w, channels), where in_w is the width of the
    input tensor, in_h is the height of the input tensor and in_d is the depth
    of the input tensor. The resizing only applies to the spatial dimensions
    (depth, height and width).

    Supported resample methods:
        'linear' : Linear interpolation
        'bilinear' : Bilinear interpolation
        'trilinear' : Trilinear interpolation
        'nearest' : Nearest neighbor interpolation
        'bicubic' : Bicubic interpolation
        'area' : Area interpolation

    Linear interpolation is the method of using a line connecting two known quantities
    to determine the value of an unknown quantity between the two known quantities.

    Nearest neighbor interpolation is to perform nearest neighbor interpolation
    in both the 3rd dimension(in height direction) and the 4th dimension(in width
    direction) on input tensor.

    Bilinear interpolation is an extension of linear interpolation for
    interpolating functions of two variables (e.g. H-direction and
    W-direction in this op) on a rectilinear 2D grid. The key idea is
    to perform linear interpolation first in one direction, and then
    again in the other direction.

    Bicubic interpolation is an extension of cubic interpolation for interpolating
    data points on a two-dimensional regular grid. The interpolated surface is
    smoother than corresponding surfaces obtained by bilinear interpolation or
    nearest-neighbor interpolation.

    Trilinear interpolation is an extension of linear interpolation for
    interpolating functions of three variables (e.g. D-direction,
    H-direction and W-direction in this op) on a rectilinear 3D grid.
    The linear interpolation is performed on three directions.
    align_corners and align_mode are optional parameters; the calculation method
    of interpolation can be selected by them.

    Area interpolation performs adaptive average pooling over the spatial
    dimensions of the input tensor. Setting mode to 'area' directly calls
    `paddle.nn.functional.adaptive_avg_pool1d`, `paddle.nn.functional.adaptive_avg_pool2d`
    or `paddle.nn.functional.adaptive_avg_pool3d`, depending on the input rank.

    Example:

    .. code-block:: text

        For scale_factor:
            if align_corners = True && out_size > 1 :
              scale_factor = (in_size-1.0)/(out_size-1.0)
            else:
              scale_factor = float(in_size/out_size)

        Linear interpolation:
            if:
                align_corners = False , align_mode = 0
                input : (N,C,W_in)
                output: (N,C,W_out) where:
                W_out = (W_{in}+0.5) * scale_{factor} - 0.5
            else:
                input : (N,C,W_in)
                output: (N,C,W_out) where:
                W_out = W_{in} * scale_{factor}

        Nearest neighbor interpolation:
          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = floor (H_{in} * scale_{factor})
              W_out = floor (W_{in} * scale_{factor})
          else:
              align_corners = True
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

        Bilinear interpolation:
          if:
              align_corners = False , align_mode = 0

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:

              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

        Bicubic interpolation:
          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5

          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

        Trilinear interpolation:
          if:
              align_corners = False , align_mode = 0
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

    For details of linear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Linear_interpolation.

    For details of nearest neighbor interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.

    For details of bilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bilinear_interpolation.

    For details of bicubic interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bicubic_interpolation

    For details of trilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Trilinear_interpolation.

    Parameters:
        x (Tensor): The input passed to ``forward``. A 3-D, 4-D or 5-D Tensor whose data
             type is float32, float64, or uint8; its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
             layer, the shape is (out_w, ) when input is a 3-D Tensor, the shape is (out_h, out_w)
             when input is a 4-D Tensor and is (out_d, out_h, out_w) when input is a 5-D Tensor.
             Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
             If a Tensor, its dimensions size should be a 1.
        scale_factor (float|Tensor|list|tuple|None): The multiplier for the input height or width. At
             least one of :attr:`size` or :attr:`scale_factor` must be set.
             And :attr:`size` has a higher priority than :attr:`scale_factor`. Has to match input size
             if it is either a list or a tuple or a Tensor.
             Default: None.
        mode (str): The resample method. It supports 'linear', 'area', 'nearest', 'bilinear',
             'bicubic' and 'trilinear' currently. The value is lower-cased before use.
             Default: 'nearest'
        align_corners (bool): An optional bool. If True, the centers of the 4 corner pixels of the
             input and output tensors are aligned, preserving the values at the
             corner pixels.
             Default: False
        align_mode (int): An optional parameter for linear/bilinear/trilinear interpolation. Refer to
             the formula in the example above: it can be '0' for src_idx = scale_factor*(dst_indx+0.5)-0.5,
             or '1' for src_idx = scale_factor*dst_index.
             Default: 0
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from: `"NCW"`, `"NWC"`, `"NCHW"`,
            `"NHWC"`, `"NCDHW"`, `"NDHWC"`; it should match the rank of the input (e.g. `"NCW"` or
            `"NWC"` for a 3-D input). The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in
            the order of: `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`,
            the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`
    Returns:
        A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels),
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
        or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import numpy as np

            input_data = np.random.rand(2,3,6,10).astype("float32")
            upsample_out = paddle.nn.Upsample(size=[12,12])

            input = paddle.to_tensor(input_data)
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]

    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=False,
                 align_mode=0,
                 data_format='NCHW',
                 name=None):
        super(Upsample, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        # Normalize the case so 'Bilinear', 'BILINEAR', ... are all accepted.
        self.mode = mode.lower()
        self.align_corners = align_corners
        self.align_mode = align_mode
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        out = F.interpolate(x,
                            size=self.size,
                            scale_factor=self.scale_factor,
                            mode=self.mode,
                            align_corners=self.align_corners,
                            align_mode=self.align_mode,
                            data_format=self.data_format,
                            name=self.name)

        return out

    def extra_repr(self):
        # Show scale_factor when it was given, otherwise fall back to size.
        if self.scale_factor is not None:
            main_str = 'scale_factor={}'.format(self.scale_factor)
        else:
            main_str = 'size={}'.format(self.size)
        name_str = ', name={}'.format(self.name) if self.name else ''
        return '{}, mode={}, align_corners={}, align_mode={}, data_format={}{}'.format(
            main_str, self.mode, self.align_corners, self.align_mode,
            self.data_format, name_str)

X
xiaoting 已提交
413

Z
zhiboniu 已提交
414
class UpsamplingNearest2D(Layer):
    """
    This op upsamples a batch of images, using nearest neighbours' pixel values.
    The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w)
    (or (num_batches, in_h, in_w, channels) when :attr:`data_format` is `"NHWC"`),
    where in_w is width of the input tensor, in_h is the height of the input tensor.
    And the upsampling only applies on the two dimensions(height and width).
    Nearest neighbor interpolation is to perform nearest neighbor interpolation
    in both the height direction and the width direction on input tensor.
    This layer always uses align_corners=False and align_mode=0.

    For details of nearest neighbor interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.

    Parameters:
        x (Tensor): The input passed to ``forward``. A 4-D Tensor whose data type is
                          float32, float64, or uint8; its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
             layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
             Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
             If a Tensor, its dimensions size should be a 1.
        scale_factor (float|int|list|tuple|Tensor|None): The multiplier for the input height or width. At
             least one of :attr:`size` or :attr:`scale_factor` must be set.
             And :attr:`size` has a higher priority than :attr:`scale_factor`.
             Has to match input size if it is either a list or a tuple or a Tensor.
             Default: None.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. Since the input is a 4-D Tensor, it should be
            `"NCHW"` or `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in
            the order of: `[batch_size, input_channels, input_height, input_width]`. When it is `"NHWC"`,
            the data is stored in the order of: `[batch_size, input_height, input_width, input_channels]`.
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`
    Returns:
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels).


    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.rand(shape=(2,3,6,10), dtype="float32")
            upsample_out = paddle.nn.UpsamplingNearest2D(size=[12,12])
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 data_format='NCHW',
                 name=None):
        super(UpsamplingNearest2D, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        # Interpolation settings are fixed for this layer: nearest neighbour,
        # no corner alignment.
        out = F.interpolate(x,
                            size=self.size,
                            scale_factor=self.scale_factor,
                            mode='nearest',
                            align_corners=False,
                            align_mode=0,
                            data_format=self.data_format,
                            name=self.name)

        return out

    def extra_repr(self):
        # Show scale_factor when it was given, otherwise fall back to size.
        if self.scale_factor is not None:
            main_str = 'scale_factor={}'.format(self.scale_factor)
        else:
            main_str = 'size={}'.format(self.size)
        name_str = ', name={}'.format(self.name) if self.name else ''
        return '{}, data_format={}{}'.format(main_str, self.data_format,
                                             name_str)

X
xiaoting 已提交
497

Z
zhiboniu 已提交
498
class UpsamplingBilinear2D(Layer):
    """
    This op upsamples a batch of images, using bilinear interpolation.
    The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w)
    (or (num_batches, in_h, in_w, channels) when :attr:`data_format` is `"NHWC"`),
    where in_w is width of the input tensor, in_h is the height of the input tensor.
    And the upsampling only applies on the two dimensions(height and width).
    Bilinear interpolation is an extension of linear interpolation for
    interpolating functions of two variables (e.g. H-direction and
    W-direction in this op) on a rectilinear 2D grid. The key idea is
    to perform linear interpolation first in one direction, and then
    again in the other direction. This layer always uses align_corners=True
    (the corner pixels of input and output are aligned) and align_mode=0.

    For details of bilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bilinear_interpolation.

    Parameters:
        x (Tensor): The input passed to ``forward``. A 4-D Tensor whose data type is
                          float32, float64, or uint8; its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
             layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
             Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
             If a Tensor, its dimensions size should be a 1.
        scale_factor (float|int|list|tuple|Tensor|None): The multiplier for the input height or width. At
             least one of :attr:`size` or :attr:`scale_factor` must be set.
             And :attr:`size` has a higher priority than :attr:`scale_factor`.
             Has to match input size if it is either a list or a tuple or a Tensor.
             Default: None.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. Since the input is a 4-D Tensor, it should be
            `"NCHW"` or `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in
            the order of: `[batch_size, input_channels, input_height, input_width]`. When it is `"NHWC"`,
            the data is stored in the order of: `[batch_size, input_height, input_width, input_channels]`.
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`
    Returns:
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input = paddle.rand(shape=(2,3,6,10), dtype="float32")
            upsample_out = paddle.nn.UpsamplingBilinear2D(size=[12,12])
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 data_format='NCHW',
                 name=None):
        super(UpsamplingBilinear2D, self).__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        # Interpolation settings are fixed for this layer: bilinear with the
        # corner pixels aligned.
        out = F.interpolate(x,
                            size=self.size,
                            scale_factor=self.scale_factor,
                            mode='bilinear',
                            align_corners=True,
                            align_mode=0,
                            data_format=self.data_format,
                            name=self.name)

        return out

    def extra_repr(self):
        # Show scale_factor when it was given, otherwise fall back to size.
        if self.scale_factor is not None:
            main_str = 'scale_factor={}'.format(self.scale_factor)
        else:
            main_str = 'size={}'.format(self.size)
        name_str = ', name={}'.format(self.name) if self.name else ''
        return '{}, data_format={}{}'.format(main_str, self.data_format,
                                             name_str)

X
xiaoting 已提交
582

Z
zhiboniu 已提交
583
class Bilinear(Layer):
    r"""

    Applies a bilinear transformation to a pair of inputs.

    .. math::

      out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,outfeatures-1

      out = out + b

    In this formula:
     - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features].
     - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features].
     - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features].
     - :math:`out_{i}`: the i-th element of out, shape is [batch_size], and out's shape is [batch_size, out_features].
     - :math:`b`: the learned bias, shape is [1, out_features].
     - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`.

    Parameters:
       in1_features (int): The dimension of each first input(`x1`).
       in2_features (int): The dimension of each second input(`x2`).
       out_features (int): The dimension of output of this layer.
       weight_attr (ParamAttr, optional): The parameter attribute for the learnable
           weight of this layer. The default value is None.
       bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias
           of this layer. If it is set to False, no bias will be added to the output units.
           If it is set to None, the bias is initialized zero. The default value is None.
       name (str, optional): The default value is None. Normally there is no need for user
           to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

    Attribute:
        **weight** (Parameter): the learnable weights of this layer, shape
        [out_features, in1_features, in2_features].

        **bias** (Parameter): the learnable bias of this layer, shape [1, out_features].

    Returns:
       Tensor: A 2-D Tensor of shape [batch_size, out_features].

    Examples:
       .. code-block:: python

        import paddle
        import numpy

        layer1 = numpy.random.random((5, 5)).astype('float32')
        layer2 = numpy.random.random((5, 4)).astype('float32')
        bilinear = paddle.nn.Bilinear(
            in1_features=5, in2_features=4, out_features=1000)
        result = bilinear(paddle.to_tensor(layer1),
                          paddle.to_tensor(layer2))     # result shape [5, 1000]

    """

    def __init__(self,
                 in1_features,
                 in2_features,
                 out_features,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super(Bilinear, self).__init__()
        # Plain configuration values; order among these is irrelevant.
        self._in1_features = in1_features
        self._in2_features = in2_features
        self._out_features = out_features
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._name = name
        self._dtype = self._helper.get_default_dtype()

        # One [in1, in2] matrix per output feature. The weight is created
        # before the bias, as in the original implementation.
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=[out_features, in1_features, in2_features],
            dtype=self._dtype,
            is_bias=False)
        self.bias = self.create_parameter(attr=self._bias_attr,
                                          shape=[1, out_features],
                                          dtype=self._dtype,
                                          is_bias=True)

    def forward(self, x1, x2):
        return F.bilinear(x1, x2, self.weight, self.bias, self._name)

    def extra_repr(self):
        summary = 'in1_features={}, in2_features={}, out_features={}, dtype={}'.format(
            self._in1_features, self._in2_features, self._out_features,
            self._dtype)
        if self._name:
            summary += ', name={}'.format(self._name)
        return summary

675

Z
zhiboniu 已提交
676
class Dropout(Layer):
    """
    Dropout regularization layer. It reduces overfitting by preventing
    co-adaptation of neurons during training, as described in
    `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/abs/1207.0580>`_

    While training, units of the input are randomly set to zero with
    probability ``p``. How the surviving values are scaled is controlled by ``mode``.

    This layer wraps ``paddle.nn.functional.dropout``; see it for more details.

    In dygraph mode, call ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float|int): Probability of setting a unit to zero. Default: 0.5
        axis (int|list|tuple): The axis or axes along which dropout is performed. Default: None.
        mode (str, optional): ['upscale_in_train' (default) | 'downscale_in_infer']

                               1. upscale_in_train (default): scale the output at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer: scale the output at inference time

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)

        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor with the same shape as the input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = np.array([[1,2,3], [4,5,6]]).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.Dropout(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x)
            print(y_train)
            print(y_test)
    """

    def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
        super(Dropout, self).__init__()

        # Settings are stored as-is and handed to F.dropout on every call.
        self.p = p
        self.axis = axis
        self.mode = mode
        self.name = name

    def forward(self, input):
        # self.training tells F.dropout whether to apply train- or inference-time behavior.
        return F.dropout(input,
                         p=self.p,
                         axis=self.axis,
                         training=self.training,
                         mode=self.mode,
                         name=self.name)

    def extra_repr(self):
        desc = 'p={}, axis={}, mode={}'.format(self.p, self.axis, self.mode)
        if self.name:
            desc += ', name={}'.format(self.name)
        return desc

748

Z
zhiboniu 已提交
749
class Dropout2D(Layer):
    """
    Channel-wise dropout for 4-D input. In a batched `NCHW` tensor, each channel
    is a 2D feature map of shape `HW`. On every forward call, each channel is
    zeroed out independently with probability `p`, using samples from a
    Bernoulli distribution.

    Dropout2D helps promote independence between feature maps, as described in:
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_

    This layer wraps ``paddle.nn.functional.dropout2d``; see it for more details.

    In dygraph mode, call ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float, optional): Probability of setting a channel to zero. Default: 0.5
        data_format (str, optional): Data format of the input; the output uses the same format. Either `NCHW` or `NHWC`. Default: `NCHW`. With `NCHW`, data is stored in the order [batch_size, input_channels, input_height, input_width].
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 4-D tensor.
        - output: 4-D tensor with the same shape as the input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = np.random.random(size=(2, 3, 4, 5)).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.Dropout2D(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x)
            print(y_train)
            print(y_test)
    """

    def __init__(self, p=0.5, data_format='NCHW', name=None):
        super(Dropout2D, self).__init__()

        # Stored verbatim; consumed by F.dropout2d in forward().
        self.p = p
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        return F.dropout2d(input,
                           p=self.p,
                           training=self.training,
                           data_format=self.data_format,
                           name=self.name)

    def extra_repr(self):
        desc = 'p={}, data_format={}'.format(self.p, self.data_format)
        if self.name:
            desc += ', name={}'.format(self.name)
        return desc

808

Z
zhiboniu 已提交
809
class Dropout3D(Layer):
    """
    Channel-wise dropout for 5-D input. In a batched `NCDHW` tensor, each channel
    is a 3D feature map of shape `DHW`. On every forward call, each channel is
    zeroed out independently with probability `p`, using samples from a
    Bernoulli distribution.

    Dropout3D helps promote independence between feature maps, as described in:
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_

    This layer wraps ``paddle.nn.functional.dropout3d``; see it for more details.

    In dygraph mode, call ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float | int): Probability of setting a channel to zero. Default: 0.5
        data_format (str, optional): Data format of the input; the output uses the same format. Either `NCDHW` or `NDHWC`. Default: `NCDHW`. With `NCDHW`, data is stored in the order [batch_size, input_channels, input_depth, input_height, input_width].
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 5-D tensor.
        - output: 5-D tensor with the same shape as the input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = np.random.random(size=(2, 3, 4, 5, 6)).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.Dropout3D(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x)
            print(y_train)
            print(y_test)
    """

    def __init__(self, p=0.5, data_format='NCDHW', name=None):
        super(Dropout3D, self).__init__()

        # Stored verbatim; consumed by F.dropout3d in forward().
        self.p = p
        self.data_format = data_format
        self.name = name

    def forward(self, input):
        return F.dropout3d(input,
                           p=self.p,
                           training=self.training,
                           data_format=self.data_format,
                           name=self.name)

    def extra_repr(self):
        desc = 'p={}, data_format={}'.format(self.p, self.data_format)
        if self.name:
            desc += ', name={}'.format(self.name)
        return desc

868

Z
zhiboniu 已提交
869
class AlphaDropout(Layer):
    """
    Alpha Dropout is a dropout variant that preserves the self-normalizing
    property. For an input with zero mean and unit standard deviation, the
    output keeps the original mean and standard deviation. It pairs well with
    the SELU activation, because dropped activations are set to the negative
    saturation value.

    For more information, please refer to:
    `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_

    In dygraph mode, call ``eval()`` to switch to evaluation mode, where dropout is disabled.

    Parameters:
        p (float | int): Probability of setting a unit to zero. Default: 0.5
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor with the same shape as the input.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            x = np.array([[-1, 1], [-1, 1]]).astype('float32')
            x = paddle.to_tensor(x)
            m = paddle.nn.AlphaDropout(p=0.5)
            y_train = m(x)
            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(x)
            print(y_train)
            # [[-0.10721093, 1.6655989 ], [-0.7791938, -0.7791938]] (randomly)
            print(y_test)
    """

    def __init__(self, p=0.5, name=None):
        super(AlphaDropout, self).__init__()
        self.p = p
        self.name = name

    def forward(self, input):
        # The current training flag decides whether F.alpha_dropout is active.
        return F.alpha_dropout(input,
                               p=self.p,
                               training=self.training,
                               name=self.name)

    def extra_repr(self):
        desc = 'p={}'.format(self.p)
        if self.name:
            desc += ', name={}'.format(self.name)
        return desc

923

Z
zhiboniu 已提交
924
class Pad1D(Layer):
    """
    Callable layer that pads a 3-D tensor according to `padding`, `mode` and `value`.
    In 'reflect' mode, pad[0] and pad[1] must not exceed width-1.

    Parameters:
        padding (Tensor|list[int]|int): The padding sizes, of integer type. An int applies the
            same padding on both sides; otherwise [len(padding)/2] dimensions
            of the input are padded, using the form (pad_left, pad_right).
        mode (str, optional): One of 'constant' (default), 'reflect', 'replicate', 'circular'. Default is 'constant'.

           - 'constant': pad with a constant value.
           - 'reflect': pad with a reflection of the input boundaries.
           - 'replicate': pad by repeating the input boundaries.
           - 'circular': pad by wrapping the input around circularly.

        value (float, optional): Value used to fill padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string, "NCL" or "NLC", giving the layout of the input data.
           Default is "NCL".
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 2, 3)
            pad = [1, 2]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad1D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[0. 1. 2. 3. 0. 0.]
            #   [0. 4. 5. 6. 0. 0.]]]
    """

    def __init__(self,
                 padding,
                 mode='constant',
                 value=0.0,
                 data_format="NCL",
                 name=None):
        super(Pad1D, self).__init__()
        # _npairs expands a scalar to [p, p]; tensors/lists/tuples pass through unchanged.
        self._pad = _npairs(padding, 1)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        pad_kwargs = dict(pad=self._pad,
                          mode=self._mode,
                          value=self._value,
                          data_format=self._data_format,
                          name=self._name)
        return F.pad(x, **pad_kwargs)

    def extra_repr(self):
        suffix = ', name={}'.format(self._name) if self._name else ''
        head = 'padding={}, mode={}, value={}, data_format={}'.format(
            self._pad, self._mode, self._value, self._data_format)
        return head + suffix

L
littletomatodonkey 已提交
992

Z
zhiboniu 已提交
993
class Pad2D(Layer):
    """
    Callable layer that pads a 4-D tensor according to `padding`, `mode` and `value`.
    In 'reflect' mode, pad[0] and pad[1] must not exceed width-1, and the
    same condition applies to the height dimension.

    Parameters:
        padding (Tensor|list[int]|int): The padding sizes, of integer type. An int applies the
            same padding in all dimensions; otherwise [len(padding)/2] dimensions of the input are padded,
            using the form (pad_left, pad_right, pad_top, pad_bottom).
        mode (str, optional): One of 'constant' (default), 'reflect', 'replicate', 'circular'. Default is 'constant'.

           - 'constant': pad with a constant value.
           - 'reflect': pad with a reflection of the input boundaries.
           - 'replicate': pad by repeating the input boundaries.
           - 'circular': pad by wrapping the input around circularly.

        value (float, optional): Value used to fill padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string, "NCHW" or "NHWC", giving the layout of the input data.
           Default is "NCHW".
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 2, 3)
            pad = [1, 0, 1, 2]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad2D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[[0. 0. 0. 0.]
            #    [0. 1. 2. 3.]
            #    [0. 4. 5. 6.]
            #    [0. 0. 0. 0.]
            #    [0. 0. 0. 0.]]]]
    """

    def __init__(self,
                 padding,
                 mode='constant',
                 value=0.0,
                 data_format="NCHW",
                 name=None):
        super(Pad2D, self).__init__()
        # _npairs expands a scalar to four equal entries; tensors/lists/tuples pass through unchanged.
        self._pad = _npairs(padding, 2)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        pad_kwargs = dict(pad=self._pad,
                          mode=self._mode,
                          value=self._value,
                          data_format=self._data_format,
                          name=self._name)
        return F.pad(x, **pad_kwargs)

    def extra_repr(self):
        suffix = ', name={}'.format(self._name) if self._name else ''
        head = 'padding={}, mode={}, value={}, data_format={}'.format(
            self._pad, self._mode, self._value, self._data_format)
        return head + suffix

L
littletomatodonkey 已提交
1065

1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127
class ZeroPad2D(Layer):
    """
    Callable layer that pads the boundaries of a 4-D tensor with zeros.

    Parameters:
        padding (Tensor | List[int] | int): The padding sizes, of integer type. An int applies the
            same padding in all dimensions; otherwise [len(padding)/2] dimensions of the input are padded,
            using the form (pad_left, pad_right, pad_top, pad_bottom).
        data_format (str): A string, "NCHW" or "NHWC", giving the layout of the input data.
           Default is "NCHW".
        name (str, optional): Default is None. Normally there is no need for
            the user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x(Tensor): Input of the zeropad2d operator, a 4-D tensor.
          The data type can be float32 or float64.
        - output(Tensor): Output of the zeropad2d operator, a 4-D tensor.
          The data type is the same as the input x.

    Examples:
        Examples are as follows.

        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import numpy as np

            input_shape = (1, 1, 2, 3)
            pad = [1, 0, 1, 2]
            data = paddle.arange(np.prod(input_shape), dtype="float32").reshape(input_shape) + 1

            my_pad = nn.ZeroPad2D(padding=pad)
            result = my_pad(data)

            print(result)
            # [[[[0. 0. 0. 0.]
            #    [0. 1. 2. 3.]
            #    [0. 4. 5. 6.]
            #    [0. 0. 0. 0.]
            #    [0. 0. 0. 0.]]]]
    """

    def __init__(self, padding, data_format="NCHW", name=None):
        super(ZeroPad2D, self).__init__()
        self._pad = _npairs(padding, 2)
        # Mode and fill value are fixed: this layer always pads with constant zeros.
        self._mode = 'constant'
        self._value = 0.
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        pad_kwargs = dict(pad=self._pad,
                          mode=self._mode,
                          value=self._value,
                          data_format=self._data_format,
                          name=self._name)
        return F.pad(x, **pad_kwargs)

    def extra_repr(self):
        suffix = ', name={}'.format(self._name) if self._name else ''
        head = 'padding={}, data_format={}'.format(self._pad, self._data_format)
        return head + suffix
1131 1132


Z
zhiboniu 已提交
1133
class Pad3D(Layer):
    """
    Callable layer that pads a 5-D tensor according to `padding`, `mode` and `value`.
    In 'reflect' mode, pad[0] and pad[1] must not exceed width-1, and the
    same condition applies to the height and depth dimensions.

    Parameters:
        padding (Tensor|list[int]|int): The padding sizes, of integer type. An int applies the
            same padding in all dimensions; otherwise [len(padding)/2] dimensions
            of the input are padded, using the form (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
        mode (str, optional): One of 'constant' (default), 'reflect', 'replicate', 'circular'. Default is 'constant'.

           - 'constant': pad with a constant value.
           - 'reflect': pad with a reflection of the input boundaries.
           - 'replicate': pad by repeating the input boundaries.
           - 'circular': pad by wrapping the input around circularly.

        value (float, optional): Value used to fill padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string, "NCDHW" or "NDHWC", giving the layout of the input data.
           Default is "NCDHW".
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 1, 2, 3)
            pad = [1, 0, 1, 2, 0, 0]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad3D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[[[0. 0. 0. 0.]
            #     [0. 1. 2. 3.]
            #     [0. 4. 5. 6.]
            #     [0. 0. 0. 0.]
            #     [0. 0. 0. 0.]]]]]
    """

    def __init__(self,
                 padding,
                 mode='constant',
                 value=0.0,
                 data_format="NCDHW",
                 name=None):
        super(Pad3D, self).__init__()
        # _npairs expands a scalar to six equal entries; tensors/lists/tuples pass through unchanged.
        self._pad = _npairs(padding, 3)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        pad_kwargs = dict(pad=self._pad,
                          mode=self._mode,
                          value=self._value,
                          data_format=self._data_format,
                          name=self._name)
        return F.pad(x, **pad_kwargs)

    def extra_repr(self):
        suffix = ', name={}'.format(self._name) if self._name else ''
        head = 'padding={}, mode={}, value={}, data_format={}'.format(
            self._pad, self._mode, self._value, self._data_format)
        return head + suffix

L
littletomatodonkey 已提交
1205

Z
zhiboniu 已提交
1206
class CosineSimilarity(Layer):
    """
    Callable layer that computes the cosine similarity between x1 and x2 along `axis`.

    Parameters:
        axis (int): Dimension along which the cosine similarity is computed. Default is 1.
        eps (float): Small value that guards against division by zero. Default is 1e-8.

    Returns:
        None

    Examples:
        .. code-block:: text

            Case 0:
                x1 = [[0.8024077  0.9927354  0.27238318 0.8344984 ]
                     [0.48949873 0.5797396  0.65444374 0.66510963]
                     [0.1031398  0.9614342  0.08365563 0.6796464 ]
                     [0.10760343 0.7461209  0.7726148  0.5801006 ]]
                x2 = [[0.62913156 0.1536727  0.9847992  0.04591406]
                     [0.9098952  0.15715368 0.8671125  0.3156102 ]
                     [0.4427798  0.54136837 0.5276275  0.32394758]
                     [0.3769419  0.8535014  0.48041078 0.9256797 ]]
                axis = 1
                eps = 1e-8
                Out: [0.5275037  0.8368967  0.75037485 0.9245899]

    Code Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            import numpy as np

            np.random.seed(0)
            x1 = np.random.rand(2,3)
            x2 = np.random.rand(2,3)
            x1 = paddle.to_tensor(x1)
            x2 = paddle.to_tensor(x2)

            cos_sim_func = nn.CosineSimilarity(axis=0)
            result = cos_sim_func(x1, x2)
            print(result)
            # [0.99806249 0.9817672  0.94987036]
    """

    def __init__(self, axis=1, eps=1e-8):
        super(CosineSimilarity, self).__init__()
        self._axis = axis
        self._eps = eps

    def forward(self, x1, x2):
        similarity = F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps)
        return similarity

    def extra_repr(self):
        return 'axis={}, eps={}'.format(self._axis, self._eps)

T
tangwei12 已提交
1262

Z
zhiboniu 已提交
1263
class Embedding(Layer):
    r"""

    Embedding Layer, used to construct a callable object of the ``Embedding`` class.
    For specific usage, refer to the code examples. It implements the function of the Embedding Layer.
    This layer looks up the embedding vectors of the ids provided by :attr:`x` .
    It automatically constructs a 2D embedding matrix based on the
    input :attr:`num_embeddings` and :attr:`embedding_dim`.

    The shape of the output Tensor is generated by appending an emb_size dimension to the
    last dimension of the input Tensor shape.

    Note:
        The id in :attr:`x` must satisfy :math:`0 \leq id < num\_embeddings` ,
        otherwise the program will throw an exception and exit.

    .. code-block:: text

        Case 1:

        x is a Tensor. padding_idx = -1
            x.data = [[1, 3], [2, 4], [4, 127]]
            x.shape = [3, 2]
        Given size = [128, 16]
        output is a Tensor:
            out.shape = [3, 2, 16]
            out.data = [[[0.129435295, 0.244512452, ..., 0.436322452],
                        [0.345421456, 0.524563927, ..., 0.144534654]],

                        [[0.345249859, 0.124939536, ..., 0.194353745],
                        [0.945345345, 0.435394634, ..., 0.435345365]],

                        [[0.945345345, 0.435394634, ..., 0.435345365],
                        [0.0,         0.0,         ..., 0.0        ]]]  # padding data
        The input padding_idx is less than 0, so it is automatically converted to padding_idx = -1 + 128 = 127.
        All-zero data is output wherever the id is 127.

    Parameters:
        num_embeddings (int): Size of the dictionary of embeddings. Must be greater than 0.
        embedding_dim (int): Size of each embedding vector. Must be greater than 0.
        padding_idx(int|None, optional): padding_idx must be in the interval [-num_embeddings, num_embeddings).
            If :math:`padding\_idx < 0`, the :math:`padding\_idx` is automatically converted
            to :math:`num\_embeddings + padding\_idx` . All-zero padding data is output whenever the lookup
            encounters :math:`padding\_idx` in the ids, and the padding data is not updated during training.
            If set to None, it has no effect on the output. Default: None.
        sparse(bool, optional): Whether to use sparse update. This parameter only
            affects the performance of the backward gradient update. Setting it to
            True is recommended because sparse update is faster. However, some optimizers do not support sparse update,
            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
            In those cases, sparse must be False. Default: False.
        weight_attr(ParamAttr, optional): The weight parameter property. Default: None, which means the
            default weight parameter property is used. See :ref:`api_ParamAttr` for usage details. In addition,
            user-defined or pre-trained word vectors can be loaded through the :attr:`weight_attr` parameter.
            The local word vectors must be converted to numpy format, and their shape
            must be consistent with [:attr:`num_embeddings`, :attr:`embedding_dim`]. Then :ref:`api_initializer_NumpyArrayInitializer`
            is used to load custom or pre-trained word vectors. See the code example for details.
        name(str|None, optional): For detailed information, please refer
               to :ref:`api_guide_Name`. Usually there is no need to set name;
               it is None by default.

    Raises:
        ValueError: If ``num_embeddings`` or ``embedding_dim`` is not greater than 0,
            or if ``padding_idx`` is outside [-num_embeddings, num_embeddings).

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
            y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)

            x = paddle.to_tensor(x_data, stop_gradient=False)
            y = paddle.to_tensor(y_data, stop_gradient=False)

            embedding = paddle.nn.Embedding(10, 3, sparse=True)

            w0=np.full(shape=(10, 3), fill_value=2).astype(np.float32)
            embedding.weight.set_value(w0)

            adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
            adam.clear_grad()

            # weight.shape = [10, 3]

            # x.data = [[3],[4],[5]]
            # x.shape = [3, 1]

            # out.data = [[2,2,2], [2,2,2], [2,2,2]]
            # out.shape = [3, 1, 3]
            out=embedding(x)
            out.backward()
            adam.step()

    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 padding_idx=None,
                 sparse=False,
                 weight_attr=None,
                 name=None):
        super(Embedding, self).__init__()
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._sparse = sparse
        self._is_distributed = False
        # The raw (possibly negative) value is kept and handed to F.embedding.
        self._padding_idx = padding_idx

        if self._num_embeddings <= 0:
            raise ValueError("num_embeddings must be greater than 0")

        if self._embedding_dim <= 0:
            raise ValueError("embedding_dim must be greater than 0")

        # Validate the user-supplied index BEFORE normalizing it. If the check
        # ran on the shifted value, an out-of-range negative index such as
        # -num_embeddings - 5 would be shifted to -5 and slip through.
        if padding_idx is not None and (padding_idx >= num_embeddings
                                        or padding_idx < -num_embeddings):
            raise ValueError("padding_idx must be within [-{}, {})".format(
                num_embeddings, num_embeddings))

        # -1 is the internal "no padding row" sentinel. A validated negative
        # index is mapped into [0, num_embeddings), so it can never collide
        # with the sentinel.
        if padding_idx is None:
            padding_idx = -1
        elif padding_idx < 0:
            padding_idx = num_embeddings + padding_idx

        self._dtype = self._helper.get_default_dtype()
        self._size = [self._num_embeddings, self._embedding_dim]

        self._weight_attr = weight_attr
        self._remote_prefetch = False
        self._name = name
        self.weight = self.create_parameter(attr=self._weight_attr,
                                            shape=self._size,
                                            dtype=self._dtype,
                                            is_bias=False)

        # In dynamic mode, zero the padding row of the freshly created weight.
        if in_dynamic_mode() and padding_idx != -1:
            with paddle.no_grad():
                self.weight[padding_idx] = 0.0

    def forward(self, x):
        return F.embedding(x,
                           weight=self.weight,
                           padding_idx=self._padding_idx,
                           sparse=self._sparse,
                           name=self._name)

    def extra_repr(self):
        main_str = '{_num_embeddings}, {_embedding_dim}'
        if self._padding_idx is not None:
            main_str += ', padding_idx={_padding_idx}'
        main_str += ', sparse={_sparse}'
        if self._name is not None:
            main_str += ', name={_name}'
        return main_str.format(**self.__dict__)
F
FNRE 已提交
1421 1422


Z
zhiboniu 已提交
1423
class Unfold(Layer):
    """
    Extract sliding local blocks from a batched 2D image tensor (im2col).

    For every position of the convolution window, all elements covered by the
    window are rearranged into a single column. Sliding the window over the
    input feature map produces a series of such columns, so an input ``x`` of
    shape [N, C, H, W] yields an output of shape [N, Cout, Lout].

    See ``paddle.nn.functional.unfold`` for how Cout and Lout are computed
    and for more details.

    Parameters:
        kernel_sizes(int|list):   The size of the convolution kernel, should be
                                  [k_h, k_w] or an integer k treated as [k, k].
        dilations(int|list):      The dilations of the convolution kernel, should be
                                  [dilation_h, dilation_w], or an integer dilation
                                  treated as [dilation, dilation].
                                  Default: 1, i.e. [1, 1].
        paddings(int|list):       The paddings of each dimension, should be
                                  [padding_top, padding_left, padding_bottom, padding_right]
                                  or [padding_h, padding_w] or an integer padding.
                                  If [padding_h, padding_w] is given, it will be expanded
                                  to [padding_h, padding_w, padding_h, padding_w]. If an
                                  integer padding is given, [padding, padding, padding,
                                  padding] will be used. Default: 0, i.e. [0, 0, 0, 0].
        strides(int|list):        The strides, should be [stride_h, stride_w]
                                  or an integer stride treated as [stride, stride].
                                  Default: 1, i.e. [1, 1].
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`


    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x = paddle.randn((100,3,224,224))
            unfold = nn.Unfold(kernel_sizes=[3, 3])
            result = unfold(x)
            print(result)

    """

    def __init__(self,
                 kernel_sizes,
                 dilations=1,
                 paddings=0,
                 strides=1,
                 name=None):
        super().__init__()

        # The settings are stored as given; normalization/validation is left
        # to F.unfold at call time.
        self.name = name
        self.strides = strides
        self.paddings = paddings
        self.dilations = dilations
        self.kernel_sizes = kernel_sizes

    def forward(self, input):
        """Apply ``F.unfold`` to ``input`` with this layer's stored settings."""
        unfold_options = dict(kernel_sizes=self.kernel_sizes,
                              dilations=self.dilations,
                              paddings=self.paddings,
                              strides=self.strides,
                              name=self.name)
        return F.unfold(input, **unfold_options)

    def extra_repr(self):
        """Summarize the window configuration (plus ``name`` when truthy)."""
        summary = 'kernel_size={}, dilation={}, padding={}, stride={}'.format(
            self.kernel_sizes, self.dilations, self.paddings, self.strides)
        if self.name:
            summary += ', name={}'.format(self.name)
        return summary
X
xiaoting 已提交
1495 1496 1497


class Fold(Layer):
    r"""

    Combines an array of sliding local blocks into a large containing
    tensor. Also known as col2im when operated on batched 2D image tensors.
    Fold calculates each combined value in the resulting large tensor by
    summing all values from all containing blocks.

    For each input :math:`x` with shape [N, C_in, L], the output shape
    [N, C_out, H_out, W_out] can be calculated as follows.

    .. math::

        H_{out} &= output\_sizes[0] \\
        W_{out} &= output\_sizes[1] \\
        C_{out} &= \frac{C_{in}}{kernel\_sizes[0]\times kernel\_sizes[1]}

    Parameters:
        output_sizes(int|list):   The size of the output, should be
                                  [output_size_h, output_size_w]
                                  or an integer o treated as [o, o].
        kernel_sizes(int|list|tuple):   The size of the convolution kernel, should be
                                  [k_h, k_w] or an integer k treated as [k, k].
        strides(int|list|tuple, optional):        The strides, should be [stride_h, stride_w]
                                  or an integer stride treated as [stride, stride].
                                  Default: 1, i.e. [1, 1].
        paddings(int|list|tuple, optional):       The paddings of each dimension, should be
                                  [padding_top, padding_left, padding_bottom, padding_right]
                                  or [padding_h, padding_w] or an integer padding.
                                  If [padding_h, padding_w] is given, it will be expanded
                                  to [padding_h, padding_w, padding_h, padding_w]. If an
                                  integer padding is given, [padding, padding, padding,
                                  padding] will be used. Default: 0, i.e. [0, 0, 0, 0].
        dilations(int|list|tuple, optional):      The dilations of the convolution kernel,
                                  should be [dilation_h, dilation_w], or an integer
                                  dilation treated as [dilation, dilation].
                                  Default: 1, i.e. [1, 1].
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`


    Returns:
        The tensor formed by combining a group of sliding local blocks.
        The output shape is [N, Cout, H, W] as described above.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x = paddle.randn([2,3*2*2,12])
            fold = nn.Fold(output_sizes=[4, 5], kernel_sizes=2)
            y = fold(x)
            # y.shape = [2,3,4,5]
    """

    def __init__(self,
                 output_sizes,
                 kernel_sizes,
                 dilations=1,
                 paddings=0,
                 strides=1,
                 name=None):
        super(Fold, self).__init__()

        # Settings are stored as given; F.fold normalizes/validates them.
        self.output_sizes = output_sizes
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.paddings = paddings
        self.strides = strides
        self.name = name

    def forward(self, input):
        """Apply ``F.fold`` to ``input`` with this layer's stored settings."""
        return F.fold(input,
                      output_sizes=self.output_sizes,
                      kernel_sizes=self.kernel_sizes,
                      strides=self.strides,
                      paddings=self.paddings,
                      dilations=self.dilations,
                      name=self.name)

    def extra_repr(self):
        """
        Summarize the layer configuration for its printed representation.

        ``output_sizes`` is included because it is a required constructor
        argument that determines the output shape; ``name`` is appended only
        when it is truthy.
        """
        name_str = ', name={}'.format(self.name) if self.name else ''
        return ('output_size={}, kernel_size={}, dilation={}, padding={}, '
                'stride={}{}'.format(self.output_sizes, self.kernel_sizes,
                                     self.dilations, self.paddings,
                                     self.strides, name_str))