From c8ed768ccc9adf6156a8d93964805f0e3679000b Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Sat, 24 Feb 2018 23:37:27 +0800 Subject: [PATCH] refine pool2d --- python/paddle/v2/fluid/layers/nn.py | 31 ++++++++++++----------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 57813af939..b8224f5604 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -1253,9 +1253,9 @@ def conv2d(input, raise ValueError("num_channels must be divisible by groups.") num_filter_channels = num_channels / groups - filter_size = utils.convert_to_list(filter_size, 2, 'conv2d.filter_size') - stride = utils.convert_to_list(stride, 2, 'conv2d.stride') - padding = utils.convert_to_list(padding, 2, 'conv2d.padding') + filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') + stride = utils.convert_to_list(stride, 2, 'stride') + padding = utils.convert_to_list(padding, 2, 'padding') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1433,8 +1433,8 @@ def sequence_last_step(input): def pool2d(input, pool_size, pool_type, - pool_stride=None, - pool_padding=None, + pool_stride=1, + pool_padding=0, global_pooling=False, use_cudnn=True, name=None): @@ -1442,20 +1442,15 @@ def pool2d(input, This function adds the operator for pooling in 2 dimensions, using the pooling configurations mentioned in input parameters. """ - if pool_padding is None: - pool_padding = [0, 0] - if pool_stride is None: - pool_stride = [1, 1] if pool_type not in ["max", "avg"]: raise ValueError( "Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", str(pool_type)) - if isinstance(pool_size, int): - pool_size = [pool_size, pool_size] - if isinstance(pool_stride, int): - pool_stride = [pool_stride, pool_stride] - if isinstance(pool_padding, int): - pool_padding = [pool_padding, pool_padding] + + pool_size = utils.convert_to_list(pool_size, 2, 'pool_size') + pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding') + pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride') + if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") @@ -1782,9 +1777,9 @@ def conv2d_transpose(input, raise TypeError("Input of conv2d_transpose must be Variable") input_channel = input.shape[1] - padding = utils.convert_to_list(padding, 2, 'conv2d_transpose.padding') - stride = utils.convert_to_list(stride, 2, 'conv2d_transpose.stride') - dilation = utils.convert_to_list(dilation, 2, 'conv2d_transpose.dilation') + padding = utils.convert_to_list(padding, 2, 'padding') + stride = utils.convert_to_list(stride, 2, 'stride') + dilation = utils.convert_to_list(dilation, 2, 'dilation') if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") -- GitLab