From cf6238fbd96dbc50f1eb81fa486ab83be5b664a0 Mon Sep 17 00:00:00 2001 From: lujun Date: Sun, 31 Mar 2019 17:08:37 +0800 Subject: [PATCH] fix merge for move dir, fix utest error, test=develop --- python/paddle/fluid/dygraph/nn.py | 133 ++++++++++++++++-- .../fluid/tests/unittests/test_layers.py | 21 +-- 2 files changed, 130 insertions(+), 24 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index de2c2268bfd..f6d613c779d 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -139,9 +139,107 @@ class Conv2D(layers.Layer): class Conv3D(layers.Layer): + """ + **Convlution3D Layer** + + The convolution3D layer calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input(Input) and + Output(Output) are in NCDHW format. Where N is batch size C is the number of + channels, D is the depth of the feature, H is the height of the feature, + and W is the width of the feature. Convlution3D is similar with Convlution2D + but adds one dimension(depth). If bias attribution and activation type are + provided, bias is added to the output of the convolution, and the + corresponding activation function is applied to the final result. + + For each input :math:`X`, the equation is: + + .. math:: + + Out = \sigma (W \\ast X + b) + + In the above equation: + + * :math:`X`: Input value, a tensor with NCDHW format. + * :math:`W`: Filter value, a tensor with MCDHW format. + * :math:`\\ast`: Convolution operation. + * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. + * :math:`\\sigma`: Activation function. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Example: + + - Input: + + Input shape: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` + + Filter shape: :math:`(C_{out}, C_{in}, D_f, H_f, W_f)` + + - Output: + Output shape: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` + + Where + + .. math:: + + D_{out}&= \\frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{strides[0]} + 1 \\\\ + H_{out}&= \\frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{strides[1]} + 1 \\\\ + W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1 + + Args: + input (Variable): The input image with [N, C, D, H, W] format. + num_filters(int): The number of filter. It is as same as the output + image channel. + filter_size (int|tuple|None): The filter size. If filter_size is a tuple, + it must contain three integers, (filter_size_D, filter_size_H, filter_size_W). + Otherwise, the filter will be a square. + stride (int|tuple): The stride size. If stride is a tuple, it must + contain three integers, (stride_D, stride_H, stride_W). Otherwise, the + stride_D = stride_H = stride_W = stride. Default: stride = 1. + padding (int|tuple): The padding size. If padding is a tuple, it must + contain three integers, (padding_D, padding_H, padding_W). Otherwise, the + padding_D = padding_H = padding_W = padding. Default: padding = 0. + dilation (int|tuple): The dilation size. If dilation is a tuple, it must + contain three integers, (dilation_D, dilation_H, dilation_W). Otherwise, the + dilation_D = dilation_H = dilation_W = dilation. Default: dilation = 1. + groups (int): The groups number of the Conv3d Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: groups=1 + param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights + of conv3d. If it is set to None or one attribute of ParamAttr, conv3d + will create ParamAttr as param_attr. If it is set to None, the parameter + is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is + :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None. + bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv3d. + If it is set to False, no bias will be added to the output units. + If it is set to None or one attribute of ParamAttr, conv3d + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn + library is installed. Default: True + act (str): Activation type, if it is set to None, activation is not appended. + Default: None. + name (str|None): A name for this layer(optional). If set None, the layer + will be named automatically. Default: None. + + Returns: + Variable: The tensor variable storing the convolution and \ + non-linearity activation result. + + Raises: + ValueError: If the shapes of input, filter_size, stride, padding and + groups mismatch. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[3, 12, 32, 32], dtype='float32') + conv3d = fluid.layers.conv3d(input=data, num_filters=2, filter_size=3, act="relu") + """ + def __init__(self, name_scope, - num_channels, num_filters, filter_size, stride=1, @@ -151,31 +249,36 @@ class Conv3D(layers.Layer): param_attr=None, bias_attr=None, use_cudnn=True, - act=None, - dtype=core.VarDesc.VarType.FP32): + act=None): assert param_attr is not False, "param_attr should not be False here." super(Conv3D, self).__init__(name_scope) self._groups = groups self._stride = utils.convert_to_list(stride, 3, 'stride') self._padding = utils.convert_to_list(padding, 3, 'padding') - self._dilation = utils.convert_to_list(dilation, 4, 'dilation') + self._dilation = utils.convert_to_list(dilation, 3, 'dilation') self._act = act if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") self._use_cudnn = use_cudnn - self._l_type = 'conv3d' - self._dtype = dtype + self._filter_size = filter_size + self._num_filters = num_filters + self._param_attr = param_attr + self._bias_attr = bias_attr - if groups is None: + def _build_once(self, input): + num_channels = input.shape[1] + self._dtype = self._helper.input_dtype(input) + + if self._groups is None: num_filter_channels = num_channels else: - if num_channels % groups != 0: + if num_channels % self._groups != 0: raise ValueError("num_channels must be divisible by groups.") - num_filter_channels = num_channels // groups + num_filter_channels = num_channels // self._groups - filter_size = utils.convert_to_list(filter_size, 3, 'filter_size') + filter_size = utils.convert_to_list(self._filter_size, 3, 'filter_size') - filter_shape = [num_filters, num_filter_channels] + filter_size + filter_shape = [self._num_filters, num_filter_channels] + filter_size def _get_default_param_initializer(): filter_elem_num = filter_size[0] * filter_size[1] * filter_size[ @@ -184,14 +287,14 @@ class Conv3D(layers.Layer): return Normal(0.0, std, 0) self._filter_param = self.create_parameter( - attr=param_attr, + attr=self._param_attr, shape=filter_shape, dtype=self._dtype, default_initializer=_get_default_param_initializer()) self._bias_param = self.create_parameter( - attr=bias_attr, - shape=[num_filters], + attr=self._bias_attr, + shape=[self._num_filters], dtype=self._dtype, is_bias=True) @@ -200,7 +303,7 @@ class Conv3D(layers.Layer): dtype=self._dtype) self._helper.append_op( - type=self._l_type, + type='conv3d', inputs={ 'Input': input, 'Filter': self._filter_param, diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 6873bec2f9d..266ecb88b42 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -564,8 +564,7 @@ class TestLayer(LayerTest): with self.static_graph(): images = layers.data( name='pixel', shape=[3, 6, 6, 6], dtype='float32') - ret = layers.conv3d( - input=images, num_filters=3, filter_size=[2, 2, 2]) + ret = layers.conv3d(input=images, num_filters=3, filter_size=2) static_ret = self.get_static_graph_result( feed={'pixel': np.ones( [2, 3, 6, 6, 6], dtype='float32')}, @@ -574,8 +573,7 @@ class TestLayer(LayerTest): with self.static_graph(): images = layers.data( name='pixel', shape=[3, 6, 6, 6], dtype='float32') - conv3d = nn.Conv3D( - 'conv3d', num_channels=3, num_filters=3, filter_size=[2, 2, 2]) + conv3d = nn.Conv3D('conv3d', num_filters=3, filter_size=2) ret = conv3d(images) static_ret2 = self.get_static_graph_result( feed={'pixel': np.ones( @@ -584,8 +582,7 @@ class TestLayer(LayerTest): with self.dynamic_graph(): images = np.ones([2, 3, 6, 6, 6], dtype='float32') - conv3d = nn.Conv3D( - 'conv3d', num_channels=3, num_filters=3, filter_size=[2, 2, 2]) + conv3d = nn.Conv3D('conv3d', num_filters=3, filter_size=2) dy_ret = conv3d(base.to_variable(images)) self.assertTrue(np.allclose(static_ret, dy_ret._numpy())) @@ -814,19 +811,25 @@ class TestLayer(LayerTest): with self.static_graph(): img = layers.data(name='pixel', shape=[3, 2, 2, 2], dtype='float32') out = layers.conv3d_transpose( - input=img, num_filters=12, output_size=[14, 14, 14]) + input=img, num_filters=12, filter_size=12, use_cudnn=False) static_rlt = self.get_static_graph_result( feed={'pixel': input_array}, fetch_list=[out])[0] with self.static_graph(): img = layers.data(name='pixel', shape=[3, 2, 2, 2], dtype='float32') conv3d_transpose = nn.Conv3DTranspose( - 'Conv3DTranspose', num_filters=12, output_size=[14, 14, 14]) + 'Conv3DTranspose', + num_filters=12, + filter_size=12, + use_cudnn=False) out = conv3d_transpose(img) static_rlt2 = self.get_static_graph_result( feed={'pixel': input_array}, fetch_list=[out])[0] with self.dynamic_graph(): conv3d_transpose = nn.Conv3DTranspose( - 'Conv3DTranspose', num_filters=12, output_size=[14, 14, 14]) + 'Conv3DTranspose', + num_filters=12, + filter_size=12, + use_cudnn=False) dy_rlt = conv3d_transpose(base.to_variable(input_array)) self.assertTrue(np.allclose(static_rlt2, static_rlt)) self.assertTrue(np.allclose(dy_rlt._numpy(), static_rlt)) -- GitLab