未验证 提交 087d8e7f 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #8551 from chengduoZH/fixbug/conv2d_python

Fix the bug of conv2d 
......@@ -21,6 +21,7 @@ from ..framework import Variable
from ..param_attr import ParamAttr
from layer_function_generator import autodoc
from tensor import concat
import utils
__all__ = [
'fc',
......@@ -1138,8 +1139,8 @@ def sequence_conv(input,
def conv2d(input,
num_filters,
filter_size,
stride=None,
padding=None,
stride=1,
padding=0,
groups=None,
param_attr=None,
bias_attr=None,
......@@ -1252,12 +1253,10 @@ def conv2d(input,
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups
if isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
if isinstance(stride, int):
stride = [stride, stride]
if isinstance(padding, int):
padding = [padding, padding]
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......@@ -1432,10 +1431,10 @@ def sequence_last_step(input):
def pool2d(input,
pool_size,
pool_type,
pool_stride=None,
pool_padding=None,
pool_size=-1,
pool_type="max",
pool_stride=1,
pool_padding=0,
global_pooling=False,
use_cudnn=True,
name=None):
......@@ -1443,20 +1442,20 @@ def pool2d(input,
This function adds the operator for pooling in 2 dimensions, using the
pooling configurations mentioned in input parameters.
"""
if pool_padding is None:
pool_padding = [0, 0]
if pool_stride is None:
pool_stride = [1, 1]
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if isinstance(pool_size, int):
pool_size = [pool_size, pool_size]
if isinstance(pool_stride, int):
pool_stride = [pool_stride, pool_stride]
if isinstance(pool_padding, int):
pool_padding = [pool_padding, pool_padding]
if global_pooling is False and pool_size == -1:
raise ValueError(
"When the global_pooling is False, pool_size must be passed "
"and be a valid value. Received pool_size: " + str(pool_size))
pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
......@@ -1685,9 +1684,9 @@ def conv2d_transpose(input,
num_filters,
output_size=None,
filter_size=None,
padding=None,
stride=None,
dilation=None,
padding=0,
stride=1,
dilation=1,
param_attr=None,
use_cudnn=True,
name=None):
......@@ -1783,26 +1782,12 @@ def conv2d_transpose(input,
raise TypeError("Input of conv2d_transpose must be Variable")
input_channel = input.shape[1]
op_attr = dict()
if isinstance(padding, int):
op_attr['paddings'] = [padding, padding]
elif padding is not None:
op_attr['paddings'] = padding
if isinstance(stride, int):
op_attr['strides'] = [stride, stride]
elif stride is not None:
op_attr['strides'] = stride
if isinstance(dilation, int):
op_attr['dilations'] = [dilation, dilation]
elif dilation is not None:
op_attr['dilations'] = dilation
padding = utils.convert_to_list(padding, 2, 'padding')
stride = utils.convert_to_list(stride, 2, 'stride')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
op_attr['use_cudnn'] = use_cudnn
if filter_size is None:
if output_size is None:
......@@ -1810,10 +1795,6 @@ def conv2d_transpose(input,
if isinstance(output_size, int):
output_size = [output_size, output_size]
padding = op_attr.get('paddings', [0, 0])
stride = op_attr.get('strides', [1, 1])
dilation = op_attr.get('dilations', [1, 1])
h_in = input.shape[2]
w_in = input.shape[3]
......@@ -1822,9 +1803,9 @@ def conv2d_transpose(input,
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
padding[1] - 1) / dilation[1] + 1
filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
else:
filter_size = utils.convert_to_list(filter_size, 2,
'conv2d_transpose.filter_size')
filter_shape = [input_channel, num_filters] + filter_size
img_filter = helper.create_parameter(
......@@ -1836,7 +1817,12 @@ def conv2d_transpose(input,
inputs={'Input': [input],
'Filter': [img_filter]},
outputs={'Output': out},
attrs=op_attr)
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'use_cudnn': use_cudnn
})
return out
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def convert_to_list(value, n, name, dtype=np.int):
"""
Converts a single numerical type or iterable of numerical
types into an numerical type list.
Arguments:
value: The value to validate and convert. Could an int, or any iterable
of ints.
n: The size of the list to be returned.
name: The name of the argument being validated, e.g. "stride" or
"filter_size". This is only used to format error messages.
dtype: the numerical type of the element of the list to be returned.
Returns:
A list of n dtypes.
Raises:
ValueError: If something else than an int/long or iterable thereof was
passed.
"""
if isinstance(value, dtype):
return [value, ] * n
else:
try:
value_list = list(value)
except TypeError:
raise ValueError("The " + name +
"'s type must be list or tuple. Received: " + str(
value))
if len(value_list) != n:
raise ValueError("The " + name + "'s length must be " + str(n) +
". Received: " + str(value))
for single_value in value_list:
try:
dtype(single_value)
except (ValueError, TypeError):
raise ValueError(
"The " + name + "'s type must be a list or tuple of " + str(
n) + " " + str(dtype) + " . Received: " + str(
value) + " "
"including element " + str(single_value) + " of type" + " "
+ str(type(single_value)))
return value_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册