diff --git a/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py b/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..73227dd3610376d85fcfc70bb2653dfd927427fd
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from paddle import fluid, nn
+import paddle.fluid.dygraph as dg
+import paddle.nn.functional as F
+import paddle.fluid.initializer as I
+import unittest
+
+
+class ConvTranspose1dTestCase(unittest.TestCase):
+    def __init__(self,
+                 methodName='runTest',
+                 batch_size=4,
+                 spatial_shape=16,
+                 in_channels=6,
+                 out_channels=8,
+                 filter_size=3,
+                 output_size=None,
+                 padding=0,
+                 output_padding=0,
+                 stride=1,
+                 dilation=1,
+                 groups=1,
+                 no_bias=False,
+                 data_format="NCL",
+                 dtype="float32"):
+        super(ConvTranspose1dTestCase, self).__init__(methodName)
+        self.batch_size = batch_size
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.spatial_shape = spatial_shape
+        self.filter_size = filter_size
+        self.output_size = output_size
+
+        self.padding = padding
+        self.output_padding = output_padding
+        self.stride = stride
+        self.dilation = dilation
+        self.groups = groups
+        self.no_bias = no_bias
+        self.data_format = data_format
+        self.dtype = dtype
+
+    def setUp(self):
+        self.channel_last = self.data_format == "NLC"
+        input_shape = (self.batch_size, self.in_channels,
+                       self.spatial_shape) if not self.channel_last else (
+                           self.batch_size, self.spatial_shape,
+                           self.in_channels)
+        self.input = np.random.randn(*input_shape).astype(self.dtype)
+
+        if isinstance(self.filter_size, int):
+            filter_size = [self.filter_size]
+        else:
+            filter_size = self.filter_size
+        self.weight_shape = weight_shape = (
+            self.in_channels,
+            self.out_channels // self.groups) + tuple(filter_size)
+        self.weight = np.random.uniform(
+            -1, 1, size=weight_shape).astype(self.dtype)
+        if not self.no_bias:
+            self.bias = np.random.uniform(
+                -1, 1, size=(self.out_channels, )).astype(self.dtype)
+        else:
+            self.bias = None
+
+    def functional(self, place):
+        main = fluid.Program()
+        start = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, start):
+                input_shape = (-1, self.in_channels,
+                               -1) if not self.channel_last else (
+                                   -1, -1, self.in_channels)
+                x_var = fluid.data("input", input_shape, dtype=self.dtype)
+                w_var = fluid.data(
+                    "weight", self.weight_shape, dtype=self.dtype)
+                b_var = fluid.data(
+                    "bias", (self.out_channels, ), dtype=self.dtype)
+                y_var = F.conv_transpose1d(
+                    x_var,
+                    w_var,
+                    None if self.no_bias else b_var,
+                    output_size=self.output_size,
+                    padding=self.padding,
+                    output_padding=self.output_padding,
+                    stride=self.stride,
+                    dilation=self.dilation,
+                    groups=self.groups,
+                    data_format=self.data_format)
+                feed_dict = {"input": self.input, "weight": self.weight}
"weight": self.weight} + if self.bias is not None: + feed_dict["bias"] = self.bias + exe = fluid.Executor(place) + exe.run(start) + y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var]) + return y_np + + def paddle_nn_layer(self): + x_var = paddle.to_tensor(self.input) + conv = nn.ConvTranspose1d( + self.in_channels, + self.out_channels, + self.filter_size, + padding=self.padding, + output_padding=self.output_padding, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + data_format=self.data_format) + conv.weight.set_value(self.weight) + if not self.no_bias: + conv.bias.set_value(self.bias) + y_var = conv(x_var, output_size=self.output_size) + y_np = y_var.numpy() + return y_np + + def _test_equivalence(self, place): + result1 = self.functional(place) + with dg.guard(place): + result2 = self.paddle_nn_layer() + np.testing.assert_array_almost_equal(result1, result2) + + def runTest(self): + place = fluid.CPUPlace() + self._test_equivalence(place) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self._test_equivalence(place) + + +class ConvTranspose1dErrorTestCase(ConvTranspose1dTestCase): + def runTest(self): + place = fluid.CPUPlace() + with dg.guard(place): + with self.assertRaises(ValueError): + self.paddle_nn_layer() + + +def add_cases(suite): + suite.addTest(ConvTranspose1dTestCase(methodName='runTest')) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', stride=[2], no_bias=True, dilation=2)) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', + filter_size=(3), + output_size=[36], + stride=[2], + dilation=2)) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', stride=2, dilation=(2))) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', padding="valid")) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', padding='valid')) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', filter_size=1, padding=3)) + suite.addTest(ConvTranspose1dTestCase(methodName='runTest', padding=[2])) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', data_format="NLC")) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', groups=2, padding="valid")) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', + out_channels=6, + in_channels=3, + groups=3, + padding="valid")) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', + data_format="NLC", + spartial_shape=16, + output_size=18)) + suite.addTest( + ConvTranspose1dTestCase( + methodName='runTest', data_format="NLC", stride=3, + output_padding=2)) + + +def add_error_cases(suite): + suite.addTest( + ConvTranspose1dErrorTestCase( + methodName='runTest', data_format="not_valid")) + suite.addTest( + ConvTranspose1dErrorTestCase( + methodName='runTest', in_channels=5, groups=2)) + suite.addTest( + ConvTranspose1dErrorTestCase( + methodName='runTest', stride=2, output_padding=3)) + suite.addTest( + ConvTranspose1dErrorTestCase( + methodName='runTest', output_size="not_valid")) + + +def load_tests(loader, standard_tests, pattern): + suite = unittest.TestSuite() + add_cases(suite) + add_error_cases(suite) + return suite + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 131231ade67a758f26e6160e838feffdd5dfa03b..b2fe248c26c1e0733b60f1be0fb8919272e74133 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -97,6 +97,7 @@ from .layer.pooling import AdaptiveAvgPool3d 
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 131231ade67a758f26e6160e838feffdd5dfa03b..b2fe248c26c1e0733b60f1be0fb8919272e74133 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -97,6 +97,7 @@ from .layer.pooling import AdaptiveAvgPool3d  #DEFINE_ALIAS
 from .layer.conv import Conv1d  #DEFINE_ALIAS
 from .layer.conv import Conv2d  #DEFINE_ALIAS
 from .layer.conv import Conv3d  #DEFINE_ALIAS
+from .layer.conv import ConvTranspose1d  #DEFINE_ALIAS
 from .layer.conv import ConvTranspose2d  #DEFINE_ALIAS
 from .layer.conv import ConvTranspose3d  #DEFINE_ALIAS
 # from .layer.conv import TreeConv  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index c1fcc230c1c8d32d2f0371a7f33d34924a7587f5..ba3d80d40b0e079748e6b3e6a7ee0d5030c02e2b 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -71,6 +71,7 @@ from .common import unfold  #DEFINE_ALIAS
 from .common import assign  #DEFINE_ALIAS
 from .common import interpolate  #DEFINE_ALIAS
 from .conv import conv1d  #DEFINE_ALIAS
+from .conv import conv_transpose1d  #DEFINE_ALIAS
 from .conv import conv2d  #DEFINE_ALIAS
 from .conv import conv_transpose2d  #DEFINE_ALIAS
 from .conv import conv3d  #DEFINE_ALIAS
diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 47bf94bdce22c2da8e7865a5c3ace3b188cdc570..f80f200c7163836252faa4b1c932178f6bab0dff 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -15,6 +15,7 @@ from __future__ import print_function
 
 __all__ = [
     'conv1d',
+    'conv_transpose1d',
     'conv2d',
     'conv_transpose2d',
     'conv3d',
@@ -29,6 +30,7 @@ from ...fluid.layers import nn, utils
 from ...fluid.data_feeder import check_variable_and_dtype
 from ...fluid.param_attr import ParamAttr
 from ...fluid.layer_helper import LayerHelper
+from .common import pad2d
 
 
 def _is_list_or_tuple(input):
@@ -545,6 +547,260 @@ def conv2d(x,
     return out
 
 
+def conv_transpose1d(x,
+                     weight,
+                     bias=None,
+                     stride=1,
+                     padding=0,
+                     output_padding=0,
+                     groups=1,
+                     dilation=1,
+                     output_size=None,
+                     data_format="NCL",
+                     name=None):
+    """
+    The 1-D convolution transpose layer calculates the output based on the
+    input, filter, stride, padding, and dilation. Input and output are in
+    'NCL' or 'NLC' format, where N is the batch size, C is the number of
+    channels, and L is the length of the feature. For details of the
+    convolution transpose layer, please refer to the following explanation
+    and the references `therein `_.
+    If bias and an activation type are provided, the bias is added to the
+    output of the convolution, and the activation function is applied to
+    the final result.
+
+    For each input :math:`X`, the equation is:
+
+    .. math::
+
+        Out = \sigma (W \\ast X + b)
+
+    Where:
+
+    * :math:`X`: Input value, a 3-D Tensor with 'NCL' or 'NLC' format.
+    * :math:`W`: Filter value, a 3-D Tensor with 'MCK' format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D Tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, a 3-D Tensor with data format 'NCL' or 'NLC'; the shape of :math:`Out` and :math:`X` may be different.
+
+    Example:
+
+        - Input:
+
+          Input shape: :math:`(N, C_{in}, L_{in})`
+
+          Filter shape: :math:`(C_{in}, C_{out}, L_f)`
+
+        - Output:
+
+          Output shape: :math:`(N, C_{out}, L_{out})`
+
+        Where
+
+        .. math::
+
+            L^\prime_{out} &= (L_{in} - 1) * stride - pad\_left - pad\_right + dilation * (L_f - 1) + 1 + output\_padding \\\\
+            L_{out} &\in [ L^\prime_{out}, L^\prime_{out} + stride ]
+
+    Note:
+        The conv1d_transpose can be seen as the backward of the conv1d.
+        For conv1d, when stride > 1, conv1d maps multiple input shapes to the
+        same output shape, so for conv1d_transpose, when stride > 1, one input
+        shape maps to multiple output shapes.
+        If output_size is None, :math:`L_{out} = L^\prime_{out}`;
+        else, the :math:`L_{out}` of the output size must be between
+        :math:`L^\prime_{out}` and :math:`L^\prime_{out} + stride`.
+
+    Args:
+        x(Tensor): 3-D tensor with [N, C, L] or [N, L, C] format,
+            its data type is float32 or float64.
+        weight(Tensor): The convolution kernel, a Tensor with shape [C, M/g, K],
+            where M is the number of output channels(filters), g is the number of groups,
+            K is the size of the kernel.
+        bias(Tensor, optional): The bias, a Tensor with shape [M, ].
+        stride(int|tuple|list, optional): The stride size. It means the stride in transposed convolution.
+            If stride is a tuple, it must contain one integer, `(stride_size)`.
+            Default: stride = 1.
+        padding(int|list|str|tuple, optional): The padding size. The padding argument effectively adds
+            `dilation * (kernel - 1) - padding` amount of zero-padding on both sides of the input. If `padding` is a
+            string, either 'VALID' or 'SAME' is supported, which is the padding algorithm.
+            If `padding` is a tuple or list, it could be in two forms:
+            `[pad]` or `[pad_left, pad_right]`. Default: padding = 0.
+        output_padding(int|list|tuple, optional): The count of zeros to be added to the tail of each dimension.
+            If it is a tuple, it must contain one integer. Default: 0.
+        groups(int, optional): The groups number of the conv1d transpose function. Inspired by
+            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
+            when group=2, the first half of the filters is only connected to the
+            first half of the input channels, while the second half of the
+            filters is only connected to the second half of the input channels.
+            Default: groups = 1.
+        dilation(int|tuple|list, optional): The dilation size. It means the spacing between the kernel points.
+            If dilation is a tuple, it must contain one integer, `(dilation_size)`.
+            Default: dilation = 1.
+        output_size(int|tuple|list, optional): The output feature length. If output_size
+            is a tuple, it must contain one integer, `(feature_length)`. If None, the
+            output size is computed from the kernel size, padding, and stride according
+            to the formula above; otherwise, it must satisfy that formula. Default: None.
+        data_format (str, optional): Specify the data format of the input, and the data format of the output
+            will be consistent with that of the input. An optional string from: `"NCL"`, `"NLC"`.
+            The default is `"NCL"`. When it is `"NCL"`, the data is stored in the order of:
+            `[batch_size, input_channels, input_length]`.
+        name(str, optional): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually name does not need to be set and is
+            None by default.
+
+    Returns:
+        A tensor representing the result of 1-D transpose convolution, whose
+        data type is the same as the input. Its shape is (num_batches, channels, length)
+        when data_format is `"NCL"` and (num_batches, length, channels) when data_format is
+        `"NLC"`.
+
+    Raises:
+        ValueError: If `data_format` is a string, but not "NCL" or "NLC".
+        ValueError: If `padding` is a string, but not "SAME" or "VALID".
+        ValueError: If `padding` is a tuple, but the element corresponding to the input's batch size is not 0
+            or the element corresponding to the input's channel is not 0.
+        ValueError: If `output_padding` is greater than `stride`.
+        ShapeError: If the input is not a 3-D Tensor.
+        ShapeError: If the input's dimension size and the filter's dimension size are not equal.
+        ShapeError: If the input's dimension size minus the size of `stride` is not 2.
+        ShapeError: If the number of input channels is not equal to the filter's channels.
+        ShapeError: If the size of `output_size` is not equal to that of `stride`.
+
+    Examples:
+        .. code-block:: python
+
+          import paddle
+          import paddle.nn.functional as F
+          import numpy as np
+
+          paddle.disable_static()
+          # shape: (1, 2, 4)
+          x = np.array([[[4, 0, 9, 7],
+                         [8, 0, 9, 2]]]).astype(np.float32)
+          # shape: (2, 1, 2)
+          w = np.array([[[7, 0]],
+                        [[4, 2]]]).astype(np.float32)
+          x_var = paddle.to_tensor(x)
+          w_var = paddle.to_tensor(w)
+          y_var = F.conv_transpose1d(x_var, w_var)
+          y_np = y_var.numpy()
+          print(y_np)
+
+          # [[[60. 16. 99. 75.  4.]]]
+    """
+    cudnn_version = get_cudnn_version()
+    use_cudnn = cudnn_version is not None
+
+    if data_format not in ['NCL', 'NLC']:
+        raise ValueError(
+            "Attr(data_format) of conv_transpose1d got wrong value: "
+            "received {}, but only 'NCL' or 'NLC' are supported.".format(
+                data_format))
+    channel_last = (data_format == "NLC")
+    channel_dim = -1 if channel_last else 1
+
+    num_channels = x.shape[channel_dim]
+    if num_channels < 0:
+        raise ValueError("The channel dimension of the input({}) "
+                         "should be defined. Received: {}.".format(
+                             x.shape, num_channels))
+    if num_channels % groups != 0:
+        raise ValueError(
+            "The channel of input must be divisible by groups; "
+            "received: the channel of input is {}, the shape of input is {}"
+            ", the groups is {}".format(num_channels, x.shape, groups))
+
+    # update attrs
+    padding, padding_algorithm = _update_padding_nd(padding, channel_last, 1)
+
+    if len(padding) == 2:
+        padding = padding + [0] * 2
+    elif len(padding) == 1:
+        padding = padding + [0]
+    else:
+        raise ValueError(
+            "The size of padding's dimension should be 1 or 2. But got padding={}".
+            format(padding))
+
+    stride = utils.convert_to_list(stride, 1, 'stride') + [1]
+    dilation = utils.convert_to_list(dilation, 1, 'dilation') + [1]
+    output_padding = utils.convert_to_list(output_padding, 1,
+                                           'output_padding') + [0]
+    if output_padding[0] > stride[0]:
+        raise ValueError(
+            "The size of output_padding should not be greater than stride. "
+            "But got output_padding={} and stride={}".format(output_padding[0],
+                                                             stride[0]))
+ "But got output_padding={} and stride={}".format(output_padding[0], + stride[0])) + + if output_size is None: + output_size = [] + elif isinstance(output_size, (list, tuple, int)): + output_size = utils.convert_to_list(output_size, 1, 'output_size') + [1] + else: + raise ValueError("output_size should be int, or list, tuple of ints") + + op_type = 'conv2d_transpose' + num_filters = weight.shape[1] + if (num_channels == groups and num_filters == 1 and not use_cudnn): + op_type = 'depthwise_conv2d_transpose' + use_cudnn = False + + squeeze_axis = -2 if channel_last else -1 + conv2d_data_format = "NHWC" if channel_last else "NCHW" + + x = nn.unsqueeze(input=x, axes=[squeeze_axis]) + weight = nn.unsqueeze(input=weight, axes=[-1]) + + if in_dygraph_mode(): + attrs = ('output_size', output_size, 'strides', stride, 'paddings', + padding, 'padding_algorithm', padding_algorithm, 'dilations', + dilation, 'groups', groups, 'use_cudnn', use_cudnn, + 'data_format', conv2d_data_format) + out = getattr(core.ops, op_type)(x, weight, *attrs) + if bias is not None: + out = nn.elementwise_add(out, bias, axis=channel_dim) + else: + inputs = {'Input': [x], 'Filter': [weight]} + attrs = { + 'output_size': output_size, + 'strides': stride, + 'paddings': padding, + 'padding_algorithm': padding_algorithm, + 'dilations': dilation, + 'groups': groups, + 'use_cudnn': use_cudnn, + 'data_format': conv2d_data_format + } + check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'], + 'conv2d_transpose') + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + outputs = {"Output": [out]} + helper.append_op( + type=op_type, inputs=inputs, outputs=outputs, attrs=attrs) + if bias is not None: + out = nn.elementwise_add(out, bias, axis=channel_dim) + + if output_size is None: + out = pad2d( + out, + padding=[0, output_padding, 0, 0], + data_format=conv2d_data_format, + name=name) + out = nn.squeeze(input=out, axes=[squeeze_axis]) + return out + + def conv_transpose2d(x, weight, bias=None, diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 70c1f754d91ec16a8349537e49b57230afd7fb55..2fa248450b949f5cfad086be79d19259fed4c8ee 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -61,6 +61,7 @@ from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .conv import Conv1d #DEFINE_ALIAS from .conv import Conv2d #DEFINE_ALIAS from .conv import Conv3d #DEFINE_ALIAS +from .conv import ConvTranspose1d #DEFINE_ALIAS from .conv import ConvTranspose2d #DEFINE_ALIAS from .conv import ConvTranspose3d #DEFINE_ALIAS # from .conv import TreeConv #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index 6b9b792d5fce58a9b99add7e67e0492531fe4704..7d0e59fb7575c9d15d28e88a462aed4ddba47fb9 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -18,6 +18,7 @@ __all__ = [ 'Conv1d', 'Conv2d', 'Conv3d', + 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', ] @@ -374,7 +375,6 @@ class Conv2d(_ConvNd): Examples: .. code-block:: python import numpy as np - import paddle import paddle.nn as nn x = np.random.uniform(-1, 1, (2, 4, 8, 8)).astype('float32') @@ -443,6 +443,191 @@ class Conv2d(_ConvNd): return out +class ConvTranspose1d(layers.Layer): + """ + This interface is used to construct a callable object of the ``ConvTranspose1d`` class. + For more details, refer to code examples. 
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 70c1f754d91ec16a8349537e49b57230afd7fb55..2fa248450b949f5cfad086be79d19259fed4c8ee 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -61,6 +61,7 @@ from .pooling import AdaptiveAvgPool3d  #DEFINE_ALIAS
 from .conv import Conv1d  #DEFINE_ALIAS
 from .conv import Conv2d  #DEFINE_ALIAS
 from .conv import Conv3d  #DEFINE_ALIAS
+from .conv import ConvTranspose1d  #DEFINE_ALIAS
 from .conv import ConvTranspose2d  #DEFINE_ALIAS
 from .conv import ConvTranspose3d  #DEFINE_ALIAS
 # from .conv import TreeConv  #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py
index 6b9b792d5fce58a9b99add7e67e0492531fe4704..7d0e59fb7575c9d15d28e88a462aed4ddba47fb9 100644
--- a/python/paddle/nn/layer/conv.py
+++ b/python/paddle/nn/layer/conv.py
@@ -18,6 +18,7 @@ __all__ = [
     'Conv1d',
     'Conv2d',
     'Conv3d',
+    'ConvTranspose1d',
     'ConvTranspose2d',
     'ConvTranspose3d',
 ]
@@ -374,7 +375,6 @@ class Conv2d(_ConvNd):
     Examples:
         .. code-block:: python
           import numpy as np
-
           import paddle
           import paddle.nn as nn
           x = np.random.uniform(-1, 1, (2, 4, 8, 8)).astype('float32')
@@ -443,6 +443,191 @@ class Conv2d(_ConvNd):
         return out
 
 
+class ConvTranspose1d(layers.Layer):
+    """
+    This interface is used to construct a callable object of the ``ConvTranspose1d`` class.
+    For more details, refer to code examples.
+    The 1-D convolution transpose layer calculates the output based on the
+    input, filter, stride, padding, and dilation. Input and output are in
+    'NCL' or 'NLC' format, where N is the batch size, C is the number of
+    channels, and L is the length of the feature. For details of the
+    convolution transpose layer, please refer to the following explanation
+    and the references `therein `_.
+    If bias and an activation type are provided, the bias is added to the
+    output of the convolution, and the activation function is applied to
+    the final result.
+
+    For each input :math:`X`, the equation is:
+
+    .. math::
+
+        Out = \sigma (W \\ast X + b)
+
+    Where:
+
+    * :math:`X`: Input value, a 3-D Tensor with 'NCL' or 'NLC' format.
+    * :math:`W`: Kernel value, a 3-D Tensor with 'MCK' format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D Tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, a 3-D Tensor with data format 'NCL' or 'NLC'; the shape of :math:`Out` and :math:`X` may be different.
+
+    Example:
+
+        - Input:
+
+          Input shape: :math:`(N, C_{in}, L_{in})`
+
+          Filter shape: :math:`(C_{in}, C_{out}, L_f)`
+
+        - Output:
+
+          Output shape: :math:`(N, C_{out}, L_{out})`
+
+        Where
+
+        .. math::
+
+            L^\prime_{out} &= (L_{in} - 1) * stride - pad\_left - pad\_right + dilation * (L_f - 1) + 1 + output\_padding \\\\
+            L_{out} &\in [ L^\prime_{out}, L^\prime_{out} + stride ]
+
+    Note:
+        The conv1d_transpose can be seen as the backward of the conv1d. For conv1d,
+        when stride > 1, conv1d maps multiple input shapes to the same output shape,
+        so for conv1d_transpose, when stride > 1, one input shape maps to multiple
+        output shapes.
+        If output_size is None, :math:`L_{out} = L^\prime_{out}`;
+        else, the :math:`L_{out}` of the output size must be between
+        :math:`L^\prime_{out}` and :math:`L^\prime_{out} + stride`.
+
+    Args:
+        in_channels(int): The number of channels in the input image.
+        out_channels(int): The number of output channels (filters), which equals
+            the number of channels of the output feature map.
+        kernel_size(int|tuple|list): The filter size. If kernel_size is a tuple,
+            it must contain one integer, (kernel_size).
+        stride(int|tuple|list, optional): The stride size. It means the stride in transposed convolution.
+            If stride is a tuple, it must contain one integer, (stride_size).
+            Default: stride = 1.
+        padding(int|list|str|tuple, optional): The padding size. The padding argument effectively adds
+            `dilation * (kernel - 1) - padding` amount of zero-padding on both sides of the input. If `padding` is a
+            string, either 'VALID' or 'SAME' is supported, which is the padding algorithm.
+            If `padding` is a tuple or list, it could be in two forms:
+            `[pad]` or `[pad_left, pad_right]`. Default: padding = 0.
+        output_padding(int|list|tuple, optional): The count of zeros to be added to the tail of each dimension.
+            If it is a tuple, it must contain one integer. Default: 0.
+        groups(int, optional): The groups number of the conv1d transpose layer. Inspired by
+            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
+            when group=2, the first half of the filters is only connected to the
+            first half of the input channels, while the second half of the
+            filters is only connected to the second half of the input channels.
+            Default: groups = 1.
+        bias(bool, optional): Whether to use bias. Default: True.
+        dilation(int|tuple|list, optional): The dilation size. It means the spacing between the kernel points.
+            If dilation is a tuple, it must contain one integer, (dilation_size).
+            Default: dilation = 1.
+        weight_attr (ParamAttr, optional): The parameter attribute for learnable parameters/weights
+            of conv1d_transpose. If it is set to None or one attribute of ParamAttr, conv1d_transpose
+            will create ParamAttr as weight_attr. If the Initializer of the weight_attr
+            is not set, the parameter is initialized with Xavier. Default: None.
+        bias_attr (ParamAttr|bool, optional): The parameter attribute for the bias of conv1d_transpose.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv1d_transpose
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized to zero. Default: None.
+
+    Attribute:
+        **weight** (Parameter): the learnable weights of filters of this layer.
+        **bias** (Parameter or None): the learnable bias of this layer.
+
+    Shape:
+        - x(Tensor): 3-D tensor with shape (batch, in_channels, length) when data_format is
+          "NCL" or shape (batch, length, in_channels) when data_format is "NLC".
+        - output_size(int|tuple|list, optional): The output feature length. If output_size
+          is a tuple, it must contain one integer, (feature_length). If None, the output
+          size is computed from kernel_size, padding, output_padding, and stride; otherwise,
+          it must satisfy the formula above. Default: None.
+        - output(Tensor): 3-D tensor with the same data format as x; the shape of output
+          and x may be different, as given by the formula above.
+
+    Examples:
+        .. code-block:: python
+
+          import paddle
+          from paddle.nn import ConvTranspose1d
+          import numpy as np
+
+          paddle.disable_static()
+          # shape: (1, 2, 4)
+          x = np.array([[[4, 0, 9, 7],
+                         [8, 0, 9, 2]]]).astype(np.float32)
+          # shape: (2, 1, 2)
+          w = np.array([[[7, 0]],
+                        [[4, 2]]]).astype(np.float32)
+          x_t = paddle.to_tensor(x)
+          conv = ConvTranspose1d(2, 1, 2)
+          conv.weight.set_value(w)
+          y_t = conv(x_t)
+          y_np = y_t.numpy()
+          print(y_np)
+
+          # [[[60. 16. 99. 75.  4.]]]
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 output_padding=0,
+                 groups=1,
+                 bias=True,
+                 dilation=1,
+                 weight_attr=None,
+                 bias_attr=None,
+                 data_format="NCL"):
+        super(ConvTranspose1d, self).__init__()
+        assert weight_attr is not False, "weight_attr should not be False in ConvTranspose1d."
+ self._param_attr = weight_attr + self._bias_attr = bias_attr + self._groups = groups + self._in_channels = in_channels + self._out_channels = out_channels + self._output_padding = output_padding + self._data_format = data_format + self._bias = bias + + self._stride = utils.convert_to_list(stride, 1, 'stride') + self._dilation = utils.convert_to_list(dilation, 1, 'dilation') + self._kernel_size = utils.convert_to_list(kernel_size, 1, 'kernel_size') + self._padding = padding + + filter_shape = [self._in_channels, out_channels // groups + ] + self._kernel_size + self.weight = self.create_parameter( + shape=filter_shape, attr=self._param_attr) + self.bias = self.create_parameter( + attr=self._bias_attr, shape=[self._out_channels], + is_bias=True) if self._bias else None + + def forward(self, x, output_size=None): + out = F.conv_transpose1d( + x, + self.weight, + bias=self.bias, + output_size=output_size, + output_padding=self._output_padding, + padding=self._padding, + stride=self._stride, + dilation=self._dilation, + groups=self._groups, + data_format=self._data_format) + return out + + class ConvTranspose2d(_ConvNd): """ This interface is used to construct a callable object of the ``ConvTranspose2d`` class.
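
As a quick usage sketch for the new layer (illustrative only, not part of the diff), tying back to the output-length formula: with L_in = 16, kernel_size = 3, and stride = 2, L'_out = (16 - 1) * 2 + (3 - 1) + 1 = 33, and any output_size in [33, 33 + stride) may be requested at call time:

    import numpy as np
    import paddle
    from paddle import nn

    paddle.disable_static()
    x = paddle.to_tensor(np.random.randn(4, 6, 16).astype("float32"))  # NCL

    conv = nn.ConvTranspose1d(6, 8, kernel_size=3, stride=2)
    y = conv(x)
    print(y.shape)               # [4, 8, 33]

    y = conv(x, output_size=34)  # 34 lies in [33, 35), so it is accepted
    print(y.shape)               # [4, 8, 34]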