# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define the common classes to build a neural network
import paddle
17 18 19
from paddle import in_dynamic_mode
from paddle.nn import Layer

Z
zhiboniu 已提交
20
from ...fluid.dygraph import Flatten  # noqa: F401
21
from .. import functional as F
22

__all__ = []


def _npairs(x, n):
    """Expand a scalar argument into a flat list of ``n * 2`` copies.

    Tensors, lists and tuples are returned unchanged; any other value is
    repeated ``n * 2`` times, e.g. ``_npairs(1, 2) == [1, 1, 1, 1]``.

    Args:
        x: A scalar value, or a ``paddle.Tensor`` / list / tuple that is
            passed through as-is.
        n (int): Number of pairs to produce when ``x`` is a scalar.

    Returns:
        ``x`` itself if it is a Tensor, list or tuple; otherwise a list that
        holds ``n * 2`` references to ``x``.
    """
    if isinstance(x, (paddle.Tensor, list, tuple)):
        return x
    return [x] * (n * 2)


class Identity(Layer):
    r"""

    A placeholder identity operator that is argument-insensitive. For each input :math:`X` ,
    the output :math:`Out` is:

    .. math::

        Out = X

    Parameters:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, n1, n2, ...]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, n1, n2, ...]` .

    Examples:
        .. code-block:: python

          import paddle

          input_tensor = paddle.randn(shape=[3, 2])
          layer = paddle.nn.Identity()
          out = layer(input_tensor)
          # input_tensor: [[-0.32342386 -1.200079  ]
          #                [ 0.7979031  -0.90978354]
          #                [ 0.40597573  1.8095392 ]]
          # out: [[-0.32342386 -1.200079  ]
          #      [ 0.7979031  -0.90978354]
          #      [ 0.40597573  1.8095392 ]]


    """

    def __init__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and deliberately
        # ignored, so the layer is argument-insensitive.
        super().__init__()

    def forward(self, input):
        # The input object itself is returned; no copy is made.
        return input


class Linear(Layer):
    r"""

    Fully-connected linear transformation layer. For each input :math:`X` ,
    the equation is:

    .. math::

        Out = XW + b

    where :math:`W` is the weight and :math:`b` is the bias.

    Linear layer takes only one multi-dimensional tensor as input with the
    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
    number of additional dimensions. It multiplies input tensor with the weight
    (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
    an output tensor of shape :math:`[batch\_size, *, out\_features]` .
    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
    shape :math:`[out\_features]` ) will be created and added to the output.

    Parameters:
        in_features (int): The number of input units.
        out_features (int): The number of output units.
        weight_attr (ParamAttr, optional): The attribute for the learnable
            weight of this layer. The default value is None, in which case the
            weight is created with the framework's default weight initializer.
            For detailed information, please refer to paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
            to paddle.ParamAttr. The default value is None and the bias will be
            initialized to zero.
        name (str, optional): Normally there is no need for user to set this parameter.
            For detailed information, please refer to :ref:`api_guide_Name` .

    Attribute:
        **weight** (Parameter): the learnable weight of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Shape:
        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` .
        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` .

    Examples:
        .. code-block:: python

          import paddle

          # Define the linear layer.
          weight_attr = paddle.ParamAttr(
              name="weight",
              initializer=paddle.nn.initializer.Constant(value=0.5))
          bias_attr = paddle.ParamAttr(
              name="bias",
              initializer=paddle.nn.initializer.Constant(value=1.0))
          linear = paddle.nn.Linear(2, 4, weight_attr=weight_attr, bias_attr=bias_attr)
          # linear.weight: [[0.5 0.5 0.5 0.5]
          #                 [0.5 0.5 0.5 0.5]]
          # linear.bias: [1. 1. 1. 1.]

          x = paddle.randn((3, 2), dtype="float32")
          # x: [[-0.32342386 -1.200079  ]
          #     [ 0.7979031  -0.90978354]
          #     [ 0.40597573  1.8095392 ]]
          y = linear(x)
          # y: [[0.23824859 0.23824859 0.23824859 0.23824859]
          #     [0.9440598  0.9440598  0.9440598  0.9440598 ]
          #     [2.1077576  2.1077576  2.1077576  2.1077576 ]]
    """

    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
    ):
        super().__init__()
        self._dtype = self._helper.get_default_dtype()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        # Weight is stored as [in_features, out_features] so forward computes
        # input @ weight (see Out = XW + b above).
        self.weight = self.create_parameter(
            shape=[in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )
        self.bias = self.create_parameter(
            shape=[out_features],
            attr=self._bias_attr,
            dtype=self._dtype,
            is_bias=True,
        )
        self.name = name

    def forward(self, input):
        out = F.linear(
            x=input, weight=self.weight, bias=self.bias, name=self.name
        )
        return out

    def extra_repr(self):
        # in/out feature counts are read back from the weight's shape.
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'in_features={}, out_features={}, dtype={}{}'.format(
            self.weight.shape[0], self.weight.shape[1], self._dtype, name_str
        )

class Upsample(Layer):
    """
    This op resizes a batch of images.

    The input must be a 3-D Tensor of the shape (num_batches, channels, in_w),
    a 4-D Tensor of the shape (num_batches, channels, in_h, in_w), or a 5-D
    Tensor of the shape (num_batches, channels, in_d, in_h, in_w) or
    (num_batches, in_d, in_h, in_w, channels), where in_w is the width, in_h
    is the height and in_d is the depth of the input tensor. Channel-last
    layouts are selected with :attr:`data_format`. The resizing only applies
    to the spatial dimensions (depth, height and width).

    Supporting resample methods:
        'linear' : Linear interpolation
        'bilinear' : Bilinear interpolation
        'trilinear' : Trilinear interpolation
        'nearest' : Nearest neighbor interpolation
        'bicubic' : Bicubic interpolation
        'area' : Area interpolation

    Linear interpolation is the method of using a line connecting two known quantities
    to determine the value of an unknown quantity between the two known quantities.

    Nearest neighbor interpolation is to perform nearest neighbor interpolation
    in both the 3rd dimension(in height direction) and the 4th dimension(in width
    direction) on input tensor.

    Bilinear interpolation is an extension of linear interpolation for
    interpolating functions of two variables (e.g. H-direction and
    W-direction in this op) on a rectilinear 2D grid. The key idea is
    to perform linear interpolation first in one direction, and then
    again in the other direction.

    Bicubic interpolation is an extension of cubic interpolation for interpolating
    data points on a two-dimensional regular grid. The interpolated surface is
    smoother than corresponding surfaces obtained by bilinear interpolation or
    nearest-neighbor interpolation.

    Trilinear interpolation is an extension of linear interpolation for
    interpolating functions of three variables (e.g. D-direction,
    H-direction and W-direction in this op) on a rectilinear 3D grid.
    The linear interpolation is performed on three directions.
    align_corners and align_mode are optional parameters, the calculation method
    of interpolation can be selected by them.

    Area interpolation performs area (adaptive average) pooling over the
    spatial dimensions of the input tensor. Setting mode to 'area' will directly
    call `paddle.nn.functional.adaptive_avg_pool1d` or
    `paddle.nn.functional.adaptive_avg_pool2d` or `paddle.nn.functional.adaptive_avg_pool3d`.

    Example:

    .. code-block:: text

        For scale_factor:
            if align_corners = True && out_size > 1 :
              scale_factor = (in_size-1.0)/(out_size-1.0)
            else:
              scale_factor = float(in_size/out_size)

        Linear interpolation:
            if:
                align_corners = False , align_mode = 0
                input : (N,C,W_in)
                output: (N,C,W_out) where:
                W_out = (W_{in}+0.5) * scale_{factor} - 0.5
            else:
                input : (N,C,W_in)
                output: (N,C,W_out) where:
                W_out = W_{in} * scale_{factor}

        Nearest neighbor interpolation:
          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = floor (H_{in} * scale_{factor})
              W_out = floor (W_{in} * scale_{factor})
          else:
              align_corners = True
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = round(H_{in} * scale_{factor})
              W_out = round(W_{in} * scale_{factor})

        Bilinear interpolation:
          if:
              align_corners = False , align_mode = 0
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

        Bicubic interpolation:
          if:
              align_corners = False
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,H_in,W_in)
              output: (N,C,H_out,W_out) where:
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

        Trilinear interpolation:
          if:
              align_corners = False , align_mode = 0
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              D_out = (D_{in}+0.5) * scale_{factor} - 0.5
              H_out = (H_{in}+0.5) * scale_{factor} - 0.5
              W_out = (W_{in}+0.5) * scale_{factor} - 0.5
          else:
              input : (N,C,D_in,H_in,W_in)
              output: (N,C,D_out,H_out,W_out) where:
              D_out = D_{in} * scale_{factor}
              H_out = H_{in} * scale_{factor}
              W_out = W_{in} * scale_{factor}

    For details of linear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Linear_interpolation.

    For details of nearest neighbor interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.

    For details of bilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bilinear_interpolation.

    For details of bicubic interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bicubic_interpolation

    For details of trilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Trilinear_interpolation.

    Parameters:
        x (Tensor): The input passed to ``forward``: a 3-D, 4-D or 5-D Tensor whose data type
            is float32, float64, or uint8, and whose data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
            layer, the shape is (out_w, ) when input is a 3-D Tensor, the shape is (out_h, out_w)
            when input is a 4-D Tensor and is (out_d, out_h, out_w) when input is a 5-D Tensor.
            Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
            If a Tensor, its dimensions size should be a 1.
        scale_factor (float|Tensor|list|tuple|None): The multiplier for the input height or width. At
            least one of :attr:`size` or :attr:`scale_factor` must be set.
            And :attr:`size` has a higher priority than :attr:`scale_factor`. Has to match input size
            if it is either a list or a tuple or a Tensor.
            Default: None.
        mode (str): The resample method. It supports 'linear', 'nearest', 'bilinear',
            'bicubic', 'trilinear' and 'area' currently. The value is lower-cased
            before use. Default: 'nearest'
        align_corners (bool): An optional bool, If True, the centers of the 4 corner pixels of the
            input and output tensors are aligned, preserving the values at the
            corner pixels.
            Default: False
        align_mode (int): An optional for linear/bilinear/trilinear interpolation. Refer to the formula
            in the example above, it can be '0' for src_idx = scale_factor*(dst_index+0.5)-0.5,
            can be '1' for src_idx = scale_factor*dst_index. Default: 0
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from: `"NCW"`, `"NWC"`, `"NCHW"`,
            `"NHWC"`, `"NCDHW"`, `"NDHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored
            in the order of: `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`,
            the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
        name (str, optional): The default value is None.
            Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`

    Returns:
        A 3-D Tensor of the shape (num_batches, channels, out_w) or (num_batches, out_w, channels),
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels),
        or 5-D Tensor of the shape (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels).

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.rand([2,3,6,10], dtype="float32")
            upsample_out = paddle.nn.Upsample(size=[12,12])

            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]

    """

    def __init__(
        self,
        size=None,
        scale_factor=None,
        mode='nearest',
        align_corners=False,
        align_mode=0,
        data_format='NCHW',
        name=None,
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        # Mode matching is case-insensitive: the value is normalized here.
        self.mode = mode.lower()
        self.align_corners = align_corners
        self.align_mode = align_mode
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            align_mode=self.align_mode,
            data_format=self.data_format,
            name=self.name,
        )

        return out

    def extra_repr(self):
        # Show scale_factor when it was given, otherwise the target size.
        if self.scale_factor is not None:
            main_str = 'scale_factor={}'.format(self.scale_factor)
        else:
            main_str = 'size={}'.format(self.size)
        name_str = ', name={}'.format(self.name) if self.name else ''
        return '{}, mode={}, align_corners={}, align_mode={}, data_format={}{}'.format(
            main_str,
            self.mode,
            self.align_corners,
            self.align_mode,
            self.data_format,
            name_str,
        )

class UpsamplingNearest2D(Layer):
    """
    This op upsamples a batch of images, using nearest neighbours' pixel values.
    The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w),
    where in_w is width of the input tensor, in_h is the height of the input tensor.
    And the upsampling only applies on the two dimensions(height and width).
    Nearest neighbor interpolation is to perform nearest neighbor interpolation
    in both the 3rd dimension(in height direction) and the 4th dimension(in width
    direction) on input tensor.

    For details of nearest neighbor interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation.

    Parameters:
        x (Tensor): 4-D Tensor, its data type is float32, float64, or uint8,
            its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
            layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
            Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
            If a Tensor, its dimensions size should be a 1.
        scale_factor (float|int|list|tuple|Tensor|None): The multiplier for the input height or width. At
            least one of :attr:`size` or :attr:`scale_factor` must be set.
            And :attr:`size` has a higher priority than :attr:`scale_factor`.
            Has to match input size if it is either a list or a tuple or a Tensor.
            Default: None.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
            The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
            `[batch_size, input_channels, input_height, input_width]`. When it is `"NHWC"`, the data is
            stored in the order of: `[batch_size, input_height, input_width, input_channels]`.
        name (str, optional): The default value is None.
            Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`

    Returns:
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels).

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.rand(shape=(2, 3, 6, 10), dtype="float32")
            upsample_out = paddle.nn.UpsamplingNearest2D(size=[12, 12])
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
    """

    def __init__(
        self, size=None, scale_factor=None, data_format='NCHW', name=None
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        # Mode, align_corners and align_mode are fixed for this layer.
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode='nearest',
            align_corners=False,
            align_mode=0,
            data_format=self.data_format,
            name=self.name,
        )

        return out

    def extra_repr(self):
        # Show scale_factor when it was given, otherwise the target size.
        if self.scale_factor is not None:
            main_str = 'scale_factor={}'.format(self.scale_factor)
        else:
            main_str = 'size={}'.format(self.size)
        name_str = ', name={}'.format(self.name) if self.name else ''
        return '{}, data_format={}{}'.format(
            main_str, self.data_format, name_str
        )

class UpsamplingBilinear2D(Layer):
    """
    This op upsamples a batch of images, using bilinear interpolation.
    The input must be a 4-D Tensor of the shape (num_batches, channels, in_h, in_w),
    where in_w is width of the input tensor, in_h is the height of the input tensor.
    And the upsampling only applies on the two dimensions(height and width).
    Bilinear interpolation is an extension of linear interpolation for
    interpolating functions of two variables (e.g. H-direction and
    W-direction in this op) on a rectilinear 2D grid. The key idea is
    to perform linear interpolation first in one direction, and then
    again in the other direction. This layer always uses ``align_corners=True``.

    For details of bilinear interpolation, please refer to Wikipedia:
    https://en.wikipedia.org/wiki/Bilinear_interpolation.

    Parameters:
        x (Tensor): 4-D Tensor, its data type is float32, float64, or uint8,
            its data format is specified by :attr:`data_format`.
        size (list|tuple|Tensor|None): Output shape of image resize
            layer, the shape is (out_h, out_w) when input is a 4-D Tensor.
            Default: None. If a list/tuple, each element can be an integer or a Tensor of shape: [1].
            If a Tensor, its dimensions size should be a 1.
        scale_factor (float|int|list|tuple|Tensor|None): The multiplier for the input height or width. At
            least one of :attr:`size` or :attr:`scale_factor` must be set.
            And :attr:`size` has a higher priority than :attr:`scale_factor`.
            Has to match input size if it is either a list or a tuple or a Tensor.
            Default: None.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
            The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
            `[batch_size, input_channels, input_height, input_width]`. When it is `"NHWC"`, the data is
            stored in the order of: `[batch_size, input_height, input_width, input_channels]`.
        name (str, optional): The default value is None.
            Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`

    Returns:
        A 4-D Tensor of the shape (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels).

    Examples:
        .. code-block:: python

            import paddle

            input = paddle.rand(shape=(2, 3, 6, 10), dtype="float32")
            upsample_out = paddle.nn.UpsamplingBilinear2D(size=[12, 12])
            output = upsample_out(x=input)
            print(output.shape)
            # [2, 3, 12, 12]
    """

    def __init__(
        self, size=None, scale_factor=None, data_format='NCHW', name=None
    ):
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        # Mode is fixed to bilinear and corner pixels are always aligned.
        out = F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode='bilinear',
            align_corners=True,
            align_mode=0,
            data_format=self.data_format,
            name=self.name,
        )

        return out

    def extra_repr(self):
        # Show scale_factor when it was given, otherwise the target size.
        if self.scale_factor is not None:
            main_str = 'scale_factor={}'.format(self.scale_factor)
        else:
            main_str = 'size={}'.format(self.size)
        name_str = ', name={}'.format(self.name) if self.name else ''
        return '{}, data_format={}{}'.format(
            main_str, self.data_format, name_str
        )

class Bilinear(Layer):
    r"""

    This layer performs bilinear on two inputs.

    .. math::

      out_{i} = x1 * W_{i} * {x2^\mathrm{T}}, i=0,1,...,outfeatures-1

      out = out + b

    In this formula:
     - :math:`x1`: the first input contains in1_features elements, shape is [batch_size, in1_features].
     - :math:`x2`: the second input contains in2_features elements, shape is [batch_size, in2_features].
     - :math:`W_{i}`: the i-th learned weight, shape is [in1_features, in2_features], and learned weight's shape is [out_features, in1_features, in2_features].
     - :math:`out_{i}`: the i-th element of out, shape is [batch_size], and out's shape is [batch_size, out_features].
     - :math:`b`: the learned bias, shape is [1, out_features].
     - :math:`x2^\mathrm{T}`: the transpose of :math:`x2`.

    Parameters:
        in1_features (int): The dimension of each first input(`x1`).
        in2_features (int): The dimension of each second input(`x2`).
        out_features (int): The dimension of output of this layer.
        weight_attr (ParamAttr, optional): The parameter attribute for the learnable weight of
            this layer. The default value is None.
        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias
            of this layer. If it is set to False, no bias will be added to the output units.
            If it is set to None, the bias is initialized zero. The default value is None.
        name (str, optional): The default value is None. Normally there is no need for user
            to set this property. For more information, please refer to :ref:`api_guide_Name`. Default: None.

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

        **bias** (Parameter): the learnable bias of this layer.

    Returns:
        Tensor: A 2-D Tensor of shape [batch_size, out_features].

    Examples:
        .. code-block:: python

            import paddle

            x1 = paddle.rand((5, 5)).astype('float32')
            x2 = paddle.rand((5, 4)).astype('float32')
            bilinear = paddle.nn.Bilinear(
                in1_features=5, in2_features=4, out_features=1000)
            result = bilinear(x1, x2)    # result shape [5, 1000]

    """

    def __init__(
        self,
        in1_features,
        in2_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
    ):
        super().__init__()
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._name = name
        self._in1_features = in1_features
        self._in2_features = in2_features
        self._out_features = out_features
        self._dtype = self._helper.get_default_dtype()

        # One [in1_features, in2_features] matrix W_i per output feature.
        weight_shape = [
            self._out_features,
            self._in1_features,
            self._in2_features,
        ]
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=weight_shape,
            dtype=self._dtype,
            is_bias=False,
        )
        # Bias is kept 2-D ([1, out_features]) as documented above.
        bias_shape = [1, self._out_features]
        self.bias = self.create_parameter(
            attr=self._bias_attr,
            shape=bias_shape,
            dtype=self._dtype,
            is_bias=True,
        )

    def forward(self, x1, x2):
        return F.bilinear(x1, x2, self.weight, self.bias, self._name)

    def extra_repr(self):
        name_str = ', name={}'.format(self._name) if self._name else ''
        return 'in1_features={}, in2_features={}, out_features={}, dtype={}{}'.format(
            self._in1_features,
            self._in2_features,
            self._out_features,
            self._dtype,
            name_str,
        )

700

Z
zhiboniu 已提交
701
class Dropout(Layer):
    """
    Dropout layer. Dropout is a regularization technique that reduces overfitting by
    preventing co-adaptation of neurons during training, as described in the paper
    `Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/abs/1207.0580>`_

    During training, units of the input are randomly set to zero with probability ``p``,
    and the remaining units are rescaled according to ``mode``.

    See ``paddle.nn.functional.dropout`` for more details.

    In dygraph mode, call ``eval()`` to switch to evaluation mode, in which dropout is disabled.

    Parameters:
        p (float|int, optional): Probability of setting units to zero. Default: 0.5
        axis (int|list|tuple, optional): The axis (or axes) along which dropout is performed. Default: None.
        mode (str, optional): One of ['upscale_in_train' (default) | 'downscale_in_infer'].

                               1. upscale_in_train (default): scale the output up at training time

                                  - train: out = input * mask / ( 1.0 - p )
                                  - inference: out = input

                               2. downscale_in_infer: scale the output down at inference time

                                  - train: out = input * mask
                                  - inference: out = input * (1.0 - p)

        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32")
            m = paddle.nn.Dropout(p=0.5)

            y_train = m(x)
            print(y_train)
            # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[2., 0., 6.],
            #         [0., 0., 0.]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[1., 2., 3.],
            #         [4., 5., 6.]])
    """

    def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
        super().__init__()
        # The configuration is stored as given and forwarded to F.dropout on every call.
        self.p, self.axis, self.mode, self.name = p, axis, mode, name

    def forward(self, input):
        # ``self.training`` (toggled by train()/eval()) selects train vs. inference behavior.
        return F.dropout(
            input,
            p=self.p,
            axis=self.axis,
            training=self.training,
            mode=self.mode,
            name=self.name,
        )

    def extra_repr(self):
        if self.name:
            name_str = ', name={}'.format(self.name)
        else:
            name_str = ''
        return 'p={}, axis={}, mode={}{}'.format(
            self.p, self.axis, self.mode, name_str
        )
780

781

Z
zhiboniu 已提交
782
class Dropout2D(Layer):
    """
    Channel-wise dropout for 4-D inputs. For a batched input of shape `NCHW`, every channel
    (a 2D feature map with the shape `HW`) is zeroed out as a whole, independently, on every
    forward call with probability `p`, using samples drawn from a Bernoulli distribution.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_ ,
    Dropout2D helps promote independence between feature maps.

    See ``paddle.nn.functional.dropout2d`` for more details.

    In dygraph mode, call ``eval()`` to switch to evaluation mode, in which dropout is disabled.

    Parameters:
        p (float, optional): Probability of setting units to zero. Default: 0.5
        data_format (str, optional): Data format of the input; the output uses the same format. One of `NCHW` or `NHWC`. Default: `NCHW`. With `NCHW`, the data is stored in the order: [batch_size, input_channels, input_height, input_width].
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 4-D tensor.
        - output: 4-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.rand([2, 2, 1, 3], dtype="float32")
            print(x)
            # Tensor(shape=[2, 2, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[0.10052059, 0.93890846, 0.45351565]],
            #          [[0.47507706, 0.45021373, 0.11331241]]],

            #         [[[0.53358698, 0.97375143, 0.34997326]],
            #          [[0.24758087, 0.52628899, 0.17970420]]]])

            m = paddle.nn.Dropout2D(p=0.5)
            y_train = m(x)
            print(y_train)
            # Tensor(shape=[2, 2, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[0.        , 0.        , 0.        ]],
            #          [[0.95015413, 0.90042746, 0.22662482]]],

            #         [[[1.06717396, 1.94750285, 0.69994652]],
            #          [[0.        , 0.        , 0.        ]]]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[2, 2, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[0.10052059, 0.93890846, 0.45351565]],
            #          [[0.47507706, 0.45021373, 0.11331241]]],

            #         [[[0.53358698, 0.97375143, 0.34997326]],
            #          [[0.24758087, 0.52628899, 0.17970420]]]])
    """

    def __init__(self, p=0.5, data_format='NCHW', name=None):
        super().__init__()
        # The configuration is stored as given and forwarded to F.dropout2d on every call.
        self.p, self.data_format, self.name = p, data_format, name

    def forward(self, input):
        # ``self.training`` (toggled by train()/eval()) selects train vs. inference behavior.
        return F.dropout2d(
            input,
            p=self.p,
            training=self.training,
            data_format=self.data_format,
            name=self.name,
        )

    def extra_repr(self):
        if self.name:
            name_str = ', name={}'.format(self.name)
        else:
            name_str = ''
        return 'p={}, data_format={}{}'.format(
            self.p, self.data_format, name_str
        )
861

862

Z
zhiboniu 已提交
863
class Dropout3D(Layer):
    """
    Channel-wise dropout for 5-D inputs. For a batched input of shape `NCDHW`, every channel
    (a 3D feature map with the shape `DHW`) is zeroed out as a whole, independently, on every
    forward call with probability `p`, using samples drawn from a Bernoulli distribution.

    As described in the paper
    `Efficient Object Localization Using Convolutional Networks <https://arxiv.org/abs/1411.4280>`_ ,
    Dropout3D helps promote independence between feature maps.

    See ``paddle.nn.functional.dropout3d`` for more details.

    In dygraph mode, call ``eval()`` to switch to evaluation mode, in which dropout is disabled.

    Parameters:
        p (float | int, optional): Probability of setting units to zero. Default: 0.5
        data_format (str, optional): Data format of the input; the output uses the same format. One of `NCDHW` or `NDHWC`. Default: `NCDHW`. With `NCDHW`, the data is stored in the order: [batch_size, input_channels, input_depth, input_height, input_width].
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: 5-D tensor.
        - output: 5-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(24, dtype="float32").reshape((1, 2, 2, 2, 3))
            print(x)
            # Tensor(shape=[1, 2, 2, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[[0. , 1. , 2. ],
            #            [3. , 4. , 5. ]],
            #           [[6. , 7. , 8. ],
            #            [9. , 10., 11.]]],

            #          [[[12., 13., 14.],
            #            [15., 16., 17.]],
            #           [[18., 19., 20.],
            #            [21., 22., 23.]]]]])

            m = paddle.nn.Dropout3D(p=0.5)
            y_train = m(x)
            print(y_train)
            # Tensor(shape=[1, 2, 2, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[[0. , 2. , 4. ],
            #            [6. , 8. , 10.]],
            #           [[12., 14., 16.],
            #            [18., 20., 22.]]],

            #          [[[0. , 0. , 0. ],
            #            [0. , 0. , 0. ]],
            #           [[0. , 0. , 0. ],
            #            [0. , 0. , 0. ]]]]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[1, 2, 2, 2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[[[[0. , 1. , 2. ],
            #            [3. , 4. , 5. ]],
            #           [[6. , 7. , 8. ],
            #            [9. , 10., 11.]]],

            #          [[[12., 13., 14.],
            #            [15., 16., 17.]],
            #           [[18., 19., 20.],
            #            [21., 22., 23.]]]]])
    """

    def __init__(self, p=0.5, data_format='NCDHW', name=None):
        super().__init__()
        # The configuration is stored as given and forwarded to F.dropout3d on every call.
        self.p, self.data_format, self.name = p, data_format, name

    def forward(self, input):
        # ``self.training`` (toggled by train()/eval()) selects train vs. inference behavior.
        return F.dropout3d(
            input,
            p=self.p,
            training=self.training,
            data_format=self.data_format,
            name=self.name,
        )

    def extra_repr(self):
        if self.name:
            name_str = ', name={}'.format(self.name)
        else:
            name_str = ''
        return 'p={}, data_format={}{}'.format(
            self.p, self.data_format, name_str
        )
954

955

Z
zhiboniu 已提交
956
class AlphaDropout(Layer):
    """
    Alpha Dropout: a dropout variant that preserves the self-normalizing property.
    For an input with zero mean and unit standard deviation, the output keeps the
    original mean and standard deviation. It pairs well with the SELU activation
    function, since dropped activations are set to the negative saturation value.

    For more information, please refer to:
    `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_

    In dygraph mode, call ``eval()`` to switch to evaluation mode, in which dropout is disabled.

    Parameters:
        p (float | int, optional): Probability of setting units to zero. Default: 0.5
        name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: N-D tensor.
        - output: N-D tensor, the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[-1, 1], [-1, 1]], dtype="float32")
            m = paddle.nn.AlphaDropout(p=0.5)
            y_train = m(x)
            print(y_train)
            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[-0.77919382,  1.66559887],
            #         [-0.77919382, -0.77919382]])

            m.eval()  # switch the model to test phase
            y_test = m(x)
            print(y_test)
            # Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[-1.,  1.],
            #         [-1.,  1.]])
    """

    def __init__(self, p=0.5, name=None):
        super().__init__()
        # The configuration is stored as given and forwarded to F.alpha_dropout on every call.
        self.p, self.name = p, name

    def forward(self, input):
        # ``self.training`` (toggled by train()/eval()) selects train vs. inference behavior.
        return F.alpha_dropout(
            input, p=self.p, training=self.training, name=self.name
        )

    def extra_repr(self):
        if self.name:
            name_str = ', name={}'.format(self.name)
        else:
            name_str = ''
        return 'p={}{}'.format(self.p, name_str)

1012

Z
zhiboniu 已提交
1013
class Pad1D(Layer):
    """
    This interface is used to construct a callable object of the ``Pad1D`` class.
    Pad tensor according to ``padding``, ``mode`` and ``value``.
    If mode is 'reflect', pad[0] and pad[1] must be no greater than width-1.

    Parameters:
        padding (Tensor|list[int]|int): The padding size with data type int. If it is an int,
            the same padding is used on both sides (left and right). Otherwise it has the
            form (pad_left, pad_right).
        mode (str, optional): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. Default is 'constant'.

           - 'constant' mode, uses a constant value to pad the input tensor.
           - 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
           - 'replicate' mode, uses input boundaries to pad the input tensor.
           - 'circular' mode, uses circular input to pad the input tensor.

        value (float, optional): The value to fill the padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string from: "NCL", "NLC". Specify the data format of the input data.
           Default is "NCL".
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 2, 3)
            pad = [1, 2]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad1D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[0. 1. 2. 3. 0. 0.]
            #   [0. 4. 5. 6. 0. 0.]]]
    """

    def __init__(
        self, padding, mode='constant', value=0.0, data_format="NCL", name=None
    ):
        super().__init__()
        # An int padding is expanded to [p, p] by _npairs; lists, tuples and
        # Tensors are passed through unchanged.
        self._pad = _npairs(padding, 1)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = ', name={}'.format(self._name) if self._name else ''
        return 'padding={}, mode={}, value={}, data_format={}{}'.format(
            self._pad, self._mode, self._value, self._data_format, name_str
        )
1080

L
littletomatodonkey 已提交
1081

Z
zhiboniu 已提交
1082
class Pad2D(Layer):
    """
    This interface is used to construct a callable object of the ``Pad2D`` class.
    Pad tensor according to ``padding``, ``mode`` and ``value``.
    If mode is 'reflect', pad[0] and pad[1] must be no greater
    than width-1. The height dimension has the same condition.

    Parameters:
        padding (Tensor|list[int]|int): The padding size with data type int. If it is an int,
            the same padding is used on all four sides. Otherwise it has the form
            (pad_left, pad_right, pad_top, pad_bottom).
        mode (str, optional): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. Default is 'constant'.

           - 'constant' mode, uses a constant value to pad the input tensor.
           - 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
           - 'replicate' mode, uses input boundaries to pad the input tensor.
           - 'circular' mode, uses circular input to pad the input tensor.

        value (float, optional): The value to fill the padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string from: "NCHW", "NHWC". Specify the data format of the input data.
           Default is "NCHW".
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 2, 3)
            pad = [1, 0, 1, 2]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad2D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[[0. 0. 0. 0.]
            #    [0. 1. 2. 3.]
            #    [0. 4. 5. 6.]
            #    [0. 0. 0. 0.]
            #    [0. 0. 0. 0.]]]]
    """

    def __init__(
        self, padding, mode='constant', value=0.0, data_format="NCHW", name=None
    ):
        super().__init__()
        # An int padding is expanded to [p, p, p, p] by _npairs; lists, tuples
        # and Tensors are passed through unchanged.
        self._pad = _npairs(padding, 2)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = ', name={}'.format(self._name) if self._name else ''
        return 'padding={}, mode={}, value={}, data_format={}{}'.format(
            self._pad, self._mode, self._value, self._data_format, name_str
        )
1153

L
littletomatodonkey 已提交
1154

1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199
class ZeroPad2D(Layer):
    """
    This interface is used to construct a callable object of the ``ZeroPad2D`` class.
    Pads the input tensor boundaries with zero.

    Parameters:
        padding (Tensor | List[int] | int): The padding size with data type int. If it is an int,
            the same padding is used on all four sides. Otherwise it has the form
            (pad_left, pad_right, pad_top, pad_bottom).
        data_format (str, optional): A string from: "NCHW", "NHWC". Specify the data format of the input data.
           Default is "NCHW".
        name (str, optional): The default value is None. Normally there is no need for
            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - x(Tensor): The input tensor of zeropad2d operator, which is a 4-D tensor.
          The data type can be float32, float64.
        - output(Tensor): The output tensor of zeropad2d operator, which is a 4-D tensor.
          The data type is same as input x.

    Examples:
        Examples are as follows.

        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 2, 3)
            pad = [1, 0, 1, 2]
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1

            my_pad = nn.ZeroPad2D(padding=pad)
            result = my_pad(data)

            print(result)
            # [[[[0. 0. 0. 0.]
            #    [0. 1. 2. 3.]
            #    [0. 4. 5. 6.]
            #    [0. 0. 0. 0.]
            #    [0. 0. 0. 0.]]]]
    """

    def __init__(self, padding, data_format="NCHW", name=None):
        super().__init__()
        # An int padding is expanded to [p, p, p, p] by _npairs; lists, tuples
        # and Tensors are passed through unchanged.
        self._pad = _npairs(padding, 2)
        # Zero padding is constant-mode padding with a fixed fill value of 0.0.
        self._mode = 'constant'
        self._value = 0.0
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = ', name={}'.format(self._name) if self._name else ''
        return 'padding={}, data_format={}{}'.format(
            self._pad, self._data_format, name_str
        )
1222 1223


Z
zhiboniu 已提交
1224
class Pad3D(Layer):
    """
    This interface is used to construct a callable object of the ``Pad3D`` class.
    Pad tensor according to ``padding``, ``mode`` and ``value``.
    If mode is 'reflect', pad[0] and pad[1] must be no greater
    than width-1. The height and depth dimensions have the same condition.

    Parameters:
        padding (Tensor|list[int]|int): The padding size with data type int. If it is an int,
            the same padding is used on all six sides. Otherwise it has the form
            (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
        mode (str, optional): Four modes: 'constant' (default), 'reflect', 'replicate', 'circular'. Default is 'constant'.

           - 'constant' mode, uses a constant value to pad the input tensor.
           - 'reflect' mode, uses reflection of the input boundaries to pad the input tensor.
           - 'replicate' mode, uses input boundaries to pad the input tensor.
           - 'circular' mode, uses circular input to pad the input tensor.

        value (float, optional): The value to fill the padded areas. Default is :math:`0.0`.
        data_format (str, optional): A string from: "NCDHW", "NDHWC". Specify the data format of the input data.
           Default is "NCDHW".
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            input_shape = (1, 1, 1, 2, 3)
            pad = [1, 0, 1, 2, 0, 0]
            mode = "constant"
            data = paddle.arange(paddle.prod(paddle.to_tensor(input_shape)), dtype="float32").reshape(input_shape) + 1
            my_pad = nn.Pad3D(padding=pad, mode=mode)
            result = my_pad(data)
            print(result)
            # [[[[[0. 0. 0. 0.]
            #     [0. 1. 2. 3.]
            #     [0. 4. 5. 6.]
            #     [0. 0. 0. 0.]
            #     [0. 0. 0. 0.]]]]]
    """

    def __init__(
        self,
        padding,
        mode='constant',
        value=0.0,
        data_format="NCDHW",
        name=None,
    ):
        super().__init__()
        # An int padding is expanded to six equal values by _npairs; lists,
        # tuples and Tensors are passed through unchanged.
        self._pad = _npairs(padding, 3)
        self._mode = mode
        self._value = value
        self._data_format = data_format
        self._name = name

    def forward(self, x):
        return F.pad(
            x,
            pad=self._pad,
            mode=self._mode,
            value=self._value,
            data_format=self._data_format,
            name=self._name,
        )

    def extra_repr(self):
        name_str = ', name={}'.format(self._name) if self._name else ''
        return 'padding={}, mode={}, value={}, data_format={}{}'.format(
            self._pad, self._mode, self._value, self._data_format, name_str
        )
1300

L
littletomatodonkey 已提交
1301

Z
zhiboniu 已提交
1302
class CosineSimilarity(Layer):
    """
    This interface is used to compute cosine similarity between x1 and x2 along axis.

    Parameters:
        axis (int, optional): Dimension of vectors to compute cosine similarity. Default is 1.
        eps (float, optional): Small value to avoid division by zero. Default is 1e-8.

    Returns:
        None

    Examples:
        .. code-block:: text

            Case 0:
                x1 = [[0.8024077  0.9927354  0.27238318 0.8344984 ]
                     [0.48949873 0.5797396  0.65444374 0.66510963]
                     [0.1031398  0.9614342  0.08365563 0.6796464 ]
                     [0.10760343 0.7461209  0.7726148  0.5801006 ]]
                x2 = [[0.62913156 0.1536727  0.9847992  0.04591406]
                     [0.9098952  0.15715368 0.8671125  0.3156102 ]
                     [0.4427798  0.54136837 0.5276275  0.32394758]
                     [0.3769419  0.8535014  0.48041078 0.9256797 ]]
                axis = 1
                eps = 1e-8
                Out: [0.5275037  0.8368967  0.75037485 0.9245899]

    Code Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x1 = paddle.to_tensor([[1., 2., 3.],
                                [2., 3., 4.]], dtype="float32")
            x2 = paddle.to_tensor([[8., 3., 3.],
                                [2., 3., 4.]], dtype="float32")

            cos_sim_func = nn.CosineSimilarity(axis=0)
            result = cos_sim_func(x1, x2)
            print(result)
            # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [0.65079135, 0.98058069, 1.        ])
    """

    def __init__(self, axis=1, eps=1e-8):
        super().__init__()
        self._axis = axis
        self._eps = eps

    def forward(self, x1, x2):
        return F.cosine_similarity(x1, x2, axis=self._axis, eps=self._eps)

    def extra_repr(self):
        # Read the attributes explicitly rather than formatting ``**self.__dict__``.
        # The old form depended on how the base Layer stores instance attributes,
        # and the sibling layers all format their attributes directly.
        return 'axis={}, eps={}'.format(self._axis, self._eps)

T
tangwei12 已提交
1357

Z
zhiboniu 已提交
1358
class Embedding(Layer):
    r"""

    Embedding Layer, used to construct a callable object of the ``Embedding`` class.
    For specific usage, refer to code examples. It implements the function of the Embedding Layer.
    This layer is used to look up the embedding vectors of the ids provided by :attr:`x` .
    It automatically constructs a 2D embedding matrix based on the
    input :attr:`num_embeddings` and :attr:`embedding_dim`.

    The shape of the output Tensor is generated by appending an emb_size dimension to the
    last dimension of the input Tensor shape.

    Note:
        The id in :attr:`x` must satisfy :math:`0 \leq id < num\_embeddings` ,
        otherwise the program will throw an exception and exit.

    .. code-block:: text

        Case 1:

        x is a Tensor. padding_idx = -1
            x.data = [[1, 3], [2, 4], [4, 127]]
            x.shape = [3, 2]
        Given size = [128, 16]
        output is a Tensor:
            out.shape = [3, 2, 16]
            out.data = [[[0.129435295, 0.244512452, ..., 0.436322452],
                        [0.345421456, 0.524563927, ..., 0.144534654]],

                        [[0.345249859, 0.124939536, ..., 0.194353745],
                        [0.945345345, 0.435394634, ..., 0.435345365]],

                        [[0.945345345, 0.435394634, ..., 0.435345365],
                        [0.0,         0.0,         ..., 0.0        ]]]  # padding data
        The input padding_idx is less than 0, so it is automatically converted to padding_idx = -1 + 128 = 127.
        It will output all-zero data whenever an id equals 127.

    Parameters:
        num_embeddings (int): Just one element which indicates the size
            of the dictionary of embeddings. Must be greater than 0.
        embedding_dim (int): Just one element which indicates the size of each embedding vector.
            Must be greater than 0.
        padding_idx(int|None, optional): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
            to :math:`num\_embeddings + padding\_idx` . It will output all-zero padding data whenever lookup
            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
            If set None, it makes no effect to output. Default: None.
        sparse(bool, optional): The flag indicating whether to use sparse update. This parameter only
            affects the performance of the backwards gradient update. It is recommended to set
            True because sparse update is faster. But some optimizers do not support sparse update,
            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
            In these cases, sparse must be False. Default: False.
        weight_attr(ParamAttr, optional): To specify the weight parameter property. Default: None, which means the
            default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
            user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
            The local word vector needs to be transformed into numpy format, and the shape of local word
            vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_initializer_NumpyArrayInitializer`
            is used to load custom or pre-trained word vectors. See code example for details.
        name(str|None, optional): For detailed information, please refer
               to :ref:`api_guide_Name`. Usually name is no need to set and
               None by default.

    Raises:
        ValueError: If ``num_embeddings`` or ``embedding_dim`` is not greater than 0,
            or if ``padding_idx`` is outside [-num_embeddings, num_embeddings).

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

    Returns:
        None

    Examples:

        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False)
            embedding = paddle.nn.Embedding(4, 3, sparse=True)

            w0 = paddle.to_tensor([[0., 0., 0.],
                                   [1., 1., 1.],
                                   [2., 2., 2.],
                                   [3., 3., 3.]], dtype="float32")
            embedding.weight.set_value(w0)
            print(embedding.weight)
            # Tensor(shape=[4, 3], dtype=float32, place=Place(gpu:0), stop_gradient=False,
            #        [[0., 0., 0.],
            #         [1., 1., 1.],
            #         [2., 2., 2.],
            #         [3., 3., 3.]])

            adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
            adam.clear_grad()

            out = embedding(x)
            print(out)
            # Tensor(shape=[3, 1, 3], dtype=float32, place=Place(gpu:0), stop_gradient=False,
            #        [[[0., 0., 0.]],
            #         [[1., 1., 1.]],
            #         [[3., 3., 3.]]])

            out.backward()
            adam.step()

    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        sparse=False,
        weight_attr=None,
        name=None,
    ):
        super().__init__()
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._sparse = sparse
        self._is_distributed = False
        # Keep the value exactly as the user passed it: forward() hands this
        # raw value to F.embedding, which normalises it itself.
        self._padding_idx = padding_idx

        if self._num_embeddings <= 0:
            raise ValueError("num_embeddings must be greater than 0")

        if self._embedding_dim <= 0:
            raise ValueError("embedding_dim must be greater than 0")

        # Validate the raw value *before* normalising it. Checking after the
        # conversion let out-of-range negatives through: e.g. with
        # num_embeddings=128, padding_idx=-200 became -72, passed the check
        # and silently zeroed row 56; padding_idx=-129 became the -1
        # "no padding" sentinel and was silently ignored.
        if padding_idx is not None and (
            padding_idx >= num_embeddings or padding_idx < -num_embeddings
        ):
            raise ValueError(
                "padding_idx must be within [-{}, {})".format(
                    num_embeddings, num_embeddings
                )
            )

        # Normalise to a non-negative row index; -1 is the internal sentinel
        # meaning "no padding row" (padding_idx is None).
        if padding_idx is None:
            padding_idx = -1
        elif padding_idx < 0:
            padding_idx += num_embeddings

        self._dtype = self._helper.get_default_dtype()
        self._size = [self._num_embeddings, self._embedding_dim]

        self._weight_attr = weight_attr
        self._remote_prefetch = False
        self._name = name
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=self._size,
            dtype=self._dtype,
            is_bias=False,
        )

        # Zero the padding row at construction time. This is only done in
        # dynamic mode; NOTE(review): presumably the static-graph lookup op
        # handles padding_idx itself — confirm.
        if in_dynamic_mode() and padding_idx != -1:
            with paddle.no_grad():
                self.weight[padding_idx] = 0.0

    def forward(self, x):
        return F.embedding(
            x,
            weight=self.weight,
            padding_idx=self._padding_idx,
            sparse=self._sparse,
            name=self._name,
        )

    def extra_repr(self):
        main_str = '{_num_embeddings}, {_embedding_dim}'
        if self._padding_idx is not None:
            main_str += ', padding_idx={_padding_idx}'
        main_str += ', sparse={_sparse}'
        if self._name is not None:
            main_str += ', name={_name}'
        return main_str.format(**self.__dict__)
F
FNRE 已提交
1533 1534


Z
zhiboniu 已提交
1535
class Unfold(Layer):
    r"""
    Returns a col buffer of sliding local blocks of input x, also known
    as im2col for batched 2D image tensors. For each block under the convolution filter,
    all elements will be rearranged as a column. While the convolution filter slides over
    the input feature map, a series of such columns will be formed.

    For each input :math:`x` with shape [N, C, H, W], the output shape [N, Cout, Lout]
    can be calculated as follows.

    .. math::

        dkernel[0] &= dilations[0] \times (kernel\_sizes[0] - 1) + 1 \\
        dkernel[1] &= dilations[1] \times (kernel\_sizes[1] - 1) + 1 \\
        hout &= \left\lfloor \frac{H + paddings[0] + paddings[2] - dkernel[0]}{strides[0]} \right\rfloor + 1 \\
        wout &= \left\lfloor \frac{W + paddings[1] + paddings[3] - dkernel[1]}{strides[1]} \right\rfloor + 1 \\
        Cout &= C \times kernel\_sizes[0] \times kernel\_sizes[1] \\
        Lout &= hout \times wout

    See ``paddle.nn.functional.unfold`` for more details.

    Parameters:
        kernel_sizes(int|list):   The size of convolution kernel, should be [k_h, k_w]
                                  or an integer k treated as [k, k].
        strides(int|list, optional): The strides, should be [stride_h, stride_w]
                                  or an integer stride treated as [stride, stride].
                                  For default, strides will be [1, 1].
        paddings(int|list, optional): The paddings of each dimension, should be
                                  [padding_top, padding_left, padding_bottom, padding_right]
                                  or [padding_h, padding_w] or an integer padding.
                                  If [padding_h, padding_w] was given, it will be expanded to
                                  [padding_h, padding_w, padding_h, padding_w]. If an integer
                                  padding was given, [padding, padding, padding, padding] will
                                  be used. For default, paddings will be [0, 0, 0, 0].
        dilations(int|list, optional): The dilations of convolution kernel, should be
                                  [dilation_h, dilation_w], or an integer dilation treated as
                                  [dilation, dilation]. For default, it will be [1, 1].
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`


    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x = paddle.randn((100,3,224,224))
            unfold = nn.Unfold(kernel_sizes=[3, 3])
            result = unfold(x)
            print(result)
            # result.shape = [100, 27, 49284]  (Cout = 3 * 3 * 3, Lout = 222 * 222)

    """

    def __init__(
        self, kernel_sizes, dilations=1, paddings=0, strides=1, name=None
    ):
        super().__init__()

        # Arguments are stored as given; normalisation (int -> pair, padding
        # expansion) is left to F.unfold at call time.
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.paddings = paddings
        self.strides = strides
        self.name = name

    def forward(self, input):
        return F.unfold(
            input,
            kernel_sizes=self.kernel_sizes,
            strides=self.strides,
            paddings=self.paddings,
            dilations=self.dilations,
            name=self.name,
        )

    def extra_repr(self):
        name_str = ', name={}'.format(self.name) if self.name else ''
        return 'kernel_size={}, dilation={}, padding={}, stride={}{}'.format(
            self.kernel_sizes,
            self.dilations,
            self.paddings,
            self.strides,
            name_str,
        )
X
xiaoting 已提交
1611 1612 1613


class Fold(Layer):
    r"""

    Combines an array of sliding local blocks into a large containing
    tensor, also known as col2im when operated on batched 2D image tensors. Fold calculates each
    combined value in the resulting large tensor by summing all values from all containing blocks.


    For each input :math:`x` with shape [N, C_in , L], the output shape [N, C_out, H_out, W_out]
    can be calculated as follows.

    .. math::

        H_{out} &= output\_sizes[0] \\
        W_{out} &= output\_sizes[1] \\
        C_{out} &= \frac{C_{in}}{kernel\_sizes[0]\times kernel\_sizes[1]}

    Parameters:
        output_sizes(int|list|tuple): The size of output size, should be [output_size_h, output_size_w]
                                  or an integer o treated as [o, o].
        kernel_sizes(int|list|tuple):   The size of convolution kernel, should be [k_h, k_w]
                                  or an integer k treated as [k, k].
        strides(int|list|tuple, optional):        The strides, should be [stride_h, stride_w]
                                  or an integer stride treated as [stride, stride].
                                  For default, strides will be [1, 1].
        paddings(int|list|tuple, optional):       The paddings of each dimension, should be
                                  [padding_top, padding_left, padding_bottom, padding_right]
                                  or [padding_h, padding_w] or an integer padding.
                                  If [padding_h, padding_w] was given, it will be expanded to
                                  [padding_h, padding_w, padding_h, padding_w]. If an integer
                                  padding was given, [padding, padding, padding, padding] will
                                  be used. For default, paddings will be [0, 0, 0, 0].
        dilations(int|list|tuple, optional):      The dilations of convolution kernel, should be
                                  [dilation_h, dilation_w], or an integer dilation treated as
                                  [dilation, dilation]. For default, it will be [1, 1].
        name(str, optional): The default value is None.
                             Normally there is no need for user to set this property.
                             For more information, please refer to :ref:`api_guide_Name`


    Returns:
        The tensor formed by combining a group of sliding local blocks.
        The output shape is [N, Cout, H, W] as described above.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn

            x = paddle.randn([2,3*2*2,12])
            fold = nn.Fold(output_sizes=[4, 5], kernel_sizes=2)
            y = fold(x)
            # y.shape = [2,3,4,5]

    """

    def __init__(
        self,
        output_sizes,
        kernel_sizes,
        dilations=1,
        paddings=0,
        strides=1,
        name=None,
    ):
        super().__init__()

        # Arguments are stored as given; normalisation is left to F.fold.
        self.output_sizes = output_sizes
        self.kernel_sizes = kernel_sizes
        self.dilations = dilations
        self.paddings = paddings
        self.strides = strides
        self.name = name

    def forward(self, input):
        return F.fold(
            input,
            output_sizes=self.output_sizes,
            kernel_sizes=self.kernel_sizes,
            strides=self.strides,
            paddings=self.paddings,
            dilations=self.dilations,
            name=self.name,
        )

    def extra_repr(self):
        # Unlike Unfold, Fold is defined by its target spatial size, so the
        # summary must include output_sizes to describe the layer fully.
        name_str = ', name={}'.format(self.name) if self.name else ''
        return (
            'output_size={}, kernel_size={}, dilation={}, padding={}, '
            'stride={}{}'.format(
                self.output_sizes,
                self.kernel_sizes,
                self.dilations,
                self.paddings,
                self.strides,
                name_str,
            )
        )