From 2398db5e5ce4b33a093c5b26930c4862226dea99 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 24 Feb 2018 22:36:45 +0800 Subject: [PATCH] follow comments --- python/paddle/v2/fluid/layers/nn.py | 62 ++++++++++---------------- python/paddle/v2/fluid/layers/utils.py | 54 ++++++++++++++++++++++ 2 files changed, 78 insertions(+), 38 deletions(-) create mode 100644 python/paddle/v2/fluid/layers/utils.py diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 8623e1f038..57813af939 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -21,6 +21,7 @@ from ..framework import Variable from ..param_attr import ParamAttr from layer_function_generator import autodoc from tensor import concat +import utils __all__ = [ 'fc', @@ -1138,8 +1139,8 @@ def sequence_conv(input, def conv2d(input, num_filters, filter_size, - stride=None, - padding=None, + stride=1, + padding=0, groups=None, param_attr=None, bias_attr=None, @@ -1252,17 +1253,15 @@ def conv2d(input, raise ValueError("num_channels must be divisible by groups.") num_filter_channels = num_channels / groups - if isinstance(filter_size, int): - filter_size = [filter_size, filter_size] - if isinstance(stride, int): - stride = [stride, stride] - if isinstance(padding, int): - padding = [padding, padding] + filter_size = utils.convert_to_list(filter_size, 2, 'conv2d.filter_size') + stride = utils.convert_to_list(stride, 2, 'conv2d.stride') + padding = utils.convert_to_list(padding, 2, 'conv2d.padding') + if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") input_shape = input.shape - filter_shape = [num_filters, num_filter_channels] + list(filter_size) + filter_shape = [num_filters, num_filter_channels] + filter_size def _get_default_param_initializer(): std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 @@ -1685,9 +1684,9 @@ def conv2d_transpose(input, num_filters, output_size=None, filter_size=None, - padding=None, - stride=None, - dilation=None, + padding=0, + stride=1, + dilation=1, param_attr=None, use_cudnn=True, name=None): @@ -1783,26 +1782,12 @@ def conv2d_transpose(input, raise TypeError("Input of conv2d_transpose must be Variable") input_channel = input.shape[1] - op_attr = dict() - - if isinstance(padding, int): - op_attr['paddings'] = [padding, padding] - elif padding is not None: - op_attr['paddings'] = padding - - if isinstance(stride, int): - op_attr['strides'] = [stride, stride] - elif stride is not None: - op_attr['strides'] = stride - - if isinstance(dilation, int): - op_attr['dilations'] = [dilation, dilation] - elif dilation is not None: - op_attr['dilations'] = dilation + padding = utils.convert_to_list(padding, 2, 'conv2d_transpose.padding') + stride = utils.convert_to_list(stride, 2, 'conv2d_transpose.stride') + dilation = utils.convert_to_list(dilation, 2, 'conv2d_transpose.dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") - op_attr['use_cudnn'] = use_cudnn if filter_size is None: if output_size is None: @@ -1810,10 +1795,6 @@ def conv2d_transpose(input, if isinstance(output_size, int): output_size = [output_size, output_size] - padding = op_attr.get('paddings', [0, 0]) - stride = op_attr.get('strides', [1, 1]) - dilation = op_attr.get('dilations', [1, 1]) - h_in = input.shape[2] w_in = input.shape[3] @@ -1822,11 +1803,11 @@ def conv2d_transpose(input, filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] - 1) / dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] + else: + filter_size = utils.convert_to_list(filter_size, 2, + 'conv2d_transpose.filter_size') - elif isinstance(filter_size, int): - filter_size = [filter_size, filter_size] - - filter_shape = [input_channel, num_filters] + list(filter_size) + filter_shape = [input_channel, num_filters] + filter_size img_filter = helper.create_parameter( dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) @@ -1836,7 +1817,12 @@ def conv2d_transpose(input, inputs={'Input': [input], 'Filter': [img_filter]}, outputs={'Output': out}, - attrs=op_attr) + attrs={ + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'use_cudnn': use_cudnn + }) return out diff --git a/python/paddle/v2/fluid/layers/utils.py b/python/paddle/v2/fluid/layers/utils.py new file mode 100644 index 0000000000..d04f2f86ac --- /dev/null +++ b/python/paddle/v2/fluid/layers/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def convert_to_list(value, n, name): + """Converts a single integer or iterable of integers into an integer list. + + Arguments: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the list to be returned. + name: The name of the argument being validated, e.g. "stride" or + "filter_size". This is only used to format error messages. + + Returns: + A list of n integers. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. + """ + if isinstance(value, int): + return [value, ] * n + else: + try: + value_list = list(value) + except TypeError: + raise ValueError("The " + name + + "'s type must be list or tuple. Received: " + str( + value)) + if len(value_list) != n: + raise ValueError("The " + name + "'s length must be " + str(n) + + ". Received: " + str(value)) + for single_value in value_list: + try: + int(single_value) + except (ValueError, TypeError): + raise ValueError( + "The " + name + "'s must be a list or tuple of " + str( + n) + " integers. Received: " + str(value) + " " + "including element " + str(single_value) + " of type" + " " + + str(type(single_value))) + return value_list -- GitLab