diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3453dd945d558a93a854f99209a6ea8055875d84..ead7041b7b20c7036bbea3da544f3b422c9f31fa 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -21,6 +21,7 @@ from ..framework import Variable from ..param_attr import ParamAttr from layer_function_generator import autodoc from tensor import concat +import utils __all__ = [ 'fc', @@ -1138,8 +1139,8 @@ def sequence_conv(input, def conv2d(input, num_filters, filter_size, - stride=None, - padding=None, + stride=1, + padding=0, groups=None, param_attr=None, bias_attr=None, @@ -1252,12 +1253,10 @@ def conv2d(input, raise ValueError("num_channels must be divisible by groups.") num_filter_channels = num_channels / groups - if isinstance(filter_size, int): - filter_size = [filter_size, filter_size] - if isinstance(stride, int): - stride = [stride, stride] - if isinstance(padding, int): - padding = [padding, padding] + filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') + stride = utils.convert_to_list(stride, 2, 'stride') + padding = utils.convert_to_list(padding, 2, 'padding') + if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1432,10 +1431,10 @@ def sequence_last_step(input): def pool2d(input, - pool_size, - pool_type, - pool_stride=None, - pool_padding=None, + pool_size=-1, + pool_type="max", + pool_stride=1, + pool_padding=0, global_pooling=False, use_cudnn=True, name=None): @@ -1443,20 +1442,20 @@ def pool2d(input, This function adds the operator for pooling in 2 dimensions, using the pooling configurations mentioned in input parameters. """ - if pool_padding is None: - pool_padding = [0, 0] - if pool_stride is None: - pool_stride = [1, 1] if pool_type not in ["max", "avg"]: raise ValueError( "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", str(pool_type)) - if isinstance(pool_size, int): - pool_size = [pool_size, pool_size] - if isinstance(pool_stride, int): - pool_stride = [pool_stride, pool_stride] - if isinstance(pool_padding, int): - pool_padding = [pool_padding, pool_padding] + + if global_pooling is False and pool_size == -1: + raise ValueError( + "When the global_pooling is False, pool_size must be passed " + "and be a valid value. Received pool_size: " + str(pool_size)) + + pool_size = utils.convert_to_list(pool_size, 2, 'pool_size') + pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding') + pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride') + if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1685,9 +1684,9 @@ def conv2d_transpose(input, num_filters, output_size=None, filter_size=None, - padding=None, - stride=None, - dilation=None, + padding=0, + stride=1, + dilation=1, param_attr=None, use_cudnn=True, name=None): @@ -1783,26 +1782,12 @@ def conv2d_transpose(input, raise TypeError("Input of conv2d_transpose must be Variable") input_channel = input.shape[1] - op_attr = dict() - - if isinstance(padding, int): - op_attr['paddings'] = [padding, padding] - elif padding is not None: - op_attr['paddings'] = padding - - if isinstance(stride, int): - op_attr['strides'] = [stride, stride] - elif stride is not None: - op_attr['strides'] = stride - - if isinstance(dilation, int): - op_attr['dilations'] = [dilation, dilation] - elif dilation is not None: - op_attr['dilations'] = dilation + padding = utils.convert_to_list(padding, 2, 'padding') + stride = utils.convert_to_list(stride, 2, 'stride') + dilation = utils.convert_to_list(dilation, 2, 'dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") - op_attr['use_cudnn'] = use_cudnn if filter_size is None: if output_size is None: @@ -1810,10 +1795,6 @@ def conv2d_transpose(input, if isinstance(output_size, int): output_size = [output_size, output_size] - padding = op_attr.get('paddings', [0, 0]) - stride = op_attr.get('strides', [1, 1]) - dilation = op_attr.get('dilations', [1, 1]) - h_in = input.shape[2] w_in = input.shape[3] @@ -1822,9 +1803,9 @@ def conv2d_transpose(input, filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] - 1) / dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] - - elif isinstance(filter_size, int): - filter_size = [filter_size, filter_size] + else: + filter_size = utils.convert_to_list(filter_size, 2, + 'conv2d_transpose.filter_size') filter_shape = [input_channel, num_filters] + filter_size img_filter = helper.create_parameter( @@ -1836,7 +1817,12 @@ def conv2d_transpose(input, inputs={'Input': [input], 'Filter': [img_filter]}, outputs={'Output': out}, - attrs=op_attr) + attrs={ + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'use_cudnn': use_cudnn + }) return out diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..49ec3088831dff415e042e1b0a632f63106eb07b --- /dev/null +++ b/python/paddle/fluid/layers/utils.py @@ -0,0 +1,59 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + + +def convert_to_list(value, n, name, dtype=np.int): + """ + Converts a single numerical type or iterable of numerical + types into an numerical type list. + + Arguments: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the list to be returned. + name: The name of the argument being validated, e.g. "stride" or + "filter_size". This is only used to format error messages. + dtype: the numerical type of the element of the list to be returned. + + Returns: + A list of n dtypes. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. + """ + if isinstance(value, dtype): + return [value, ] * n + else: + try: + value_list = list(value) + except TypeError: + raise ValueError("The " + name + + "'s type must be list or tuple. Received: " + str( + value)) + if len(value_list) != n: + raise ValueError("The " + name + "'s length must be " + str(n) + + ". Received: " + str(value)) + for single_value in value_list: + try: + dtype(single_value) + except (ValueError, TypeError): + raise ValueError( + "The " + name + "'s type must be a list or tuple of " + str( + n) + " " + str(dtype) + " . Received: " + str( + value) + " " + "including element " + str(single_value) + " of type" + " " + + str(type(single_value))) + return value_list