提交 2398db5e 编写于 作者: C chengduoZH

follow comments

上级 93e0609f
...@@ -21,6 +21,7 @@ from ..framework import Variable ...@@ -21,6 +21,7 @@ from ..framework import Variable
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from layer_function_generator import autodoc from layer_function_generator import autodoc
from tensor import concat from tensor import concat
import utils
__all__ = [ __all__ = [
'fc', 'fc',
...@@ -1138,8 +1139,8 @@ def sequence_conv(input, ...@@ -1138,8 +1139,8 @@ def sequence_conv(input,
def conv2d(input, def conv2d(input,
num_filters, num_filters,
filter_size, filter_size,
stride=None, stride=1,
padding=None, padding=0,
groups=None, groups=None,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
...@@ -1252,17 +1253,15 @@ def conv2d(input, ...@@ -1252,17 +1253,15 @@ def conv2d(input,
raise ValueError("num_channels must be divisible by groups.") raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels / groups num_filter_channels = num_channels / groups
if isinstance(filter_size, int): filter_size = utils.convert_to_list(filter_size, 2, 'conv2d.filter_size')
filter_size = [filter_size, filter_size] stride = utils.convert_to_list(stride, 2, 'conv2d.stride')
if isinstance(stride, int): padding = utils.convert_to_list(padding, 2, 'conv2d.padding')
stride = [stride, stride]
if isinstance(padding, int):
padding = [padding, padding]
if not isinstance(use_cudnn, bool): if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
input_shape = input.shape input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + list(filter_size) filter_shape = [num_filters, num_filter_channels] + filter_size
def _get_default_param_initializer(): def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
...@@ -1685,9 +1684,9 @@ def conv2d_transpose(input, ...@@ -1685,9 +1684,9 @@ def conv2d_transpose(input,
num_filters, num_filters,
output_size=None, output_size=None,
filter_size=None, filter_size=None,
padding=None, padding=0,
stride=None, stride=1,
dilation=None, dilation=1,
param_attr=None, param_attr=None,
use_cudnn=True, use_cudnn=True,
name=None): name=None):
...@@ -1783,26 +1782,12 @@ def conv2d_transpose(input, ...@@ -1783,26 +1782,12 @@ def conv2d_transpose(input,
raise TypeError("Input of conv2d_transpose must be Variable") raise TypeError("Input of conv2d_transpose must be Variable")
input_channel = input.shape[1] input_channel = input.shape[1]
op_attr = dict() padding = utils.convert_to_list(padding, 2, 'conv2d_transpose.padding')
stride = utils.convert_to_list(stride, 2, 'conv2d_transpose.stride')
if isinstance(padding, int): dilation = utils.convert_to_list(dilation, 2, 'conv2d_transpose.dilation')
op_attr['paddings'] = [padding, padding]
elif padding is not None:
op_attr['paddings'] = padding
if isinstance(stride, int):
op_attr['strides'] = [stride, stride]
elif stride is not None:
op_attr['strides'] = stride
if isinstance(dilation, int):
op_attr['dilations'] = [dilation, dilation]
elif dilation is not None:
op_attr['dilations'] = dilation
if not isinstance(use_cudnn, bool): if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
op_attr['use_cudnn'] = use_cudnn
if filter_size is None: if filter_size is None:
if output_size is None: if output_size is None:
...@@ -1810,10 +1795,6 @@ def conv2d_transpose(input, ...@@ -1810,10 +1795,6 @@ def conv2d_transpose(input,
if isinstance(output_size, int): if isinstance(output_size, int):
output_size = [output_size, output_size] output_size = [output_size, output_size]
padding = op_attr.get('paddings', [0, 0])
stride = op_attr.get('strides', [1, 1])
dilation = op_attr.get('dilations', [1, 1])
h_in = input.shape[2] h_in = input.shape[2]
w_in = input.shape[3] w_in = input.shape[3]
...@@ -1822,11 +1803,11 @@ def conv2d_transpose(input, ...@@ -1822,11 +1803,11 @@ def conv2d_transpose(input,
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
padding[1] - 1) / dilation[1] + 1 padding[1] - 1) / dilation[1] + 1
filter_size = [filter_size_h, filter_size_w] filter_size = [filter_size_h, filter_size_w]
else:
filter_size = utils.convert_to_list(filter_size, 2,
'conv2d_transpose.filter_size')
elif isinstance(filter_size, int): filter_shape = [input_channel, num_filters] + filter_size
filter_size = [filter_size, filter_size]
filter_shape = [input_channel, num_filters] + list(filter_size)
img_filter = helper.create_parameter( img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
...@@ -1836,7 +1817,12 @@ def conv2d_transpose(input, ...@@ -1836,7 +1817,12 @@ def conv2d_transpose(input,
inputs={'Input': [input], inputs={'Input': [input],
'Filter': [img_filter]}, 'Filter': [img_filter]},
outputs={'Output': out}, outputs={'Output': out},
attrs=op_attr) attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'use_cudnn': use_cudnn
})
return out return out
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def convert_to_list(value, n, name):
"""Converts a single integer or iterable of integers into an integer list.
Arguments:
value: The value to validate and convert. Could an int, or any iterable
of ints.
n: The size of the list to be returned.
name: The name of the argument being validated, e.g. "stride" or
"filter_size". This is only used to format error messages.
Returns:
A list of n integers.
Raises:
ValueError: If something else than an int/long or iterable thereof was
passed.
"""
if isinstance(value, int):
return [value, ] * n
else:
try:
value_list = list(value)
except TypeError:
raise ValueError("The " + name +
"'s type must be list or tuple. Received: " + str(
value))
if len(value_list) != n:
raise ValueError("The " + name + "'s length must be " + str(n) +
". Received: " + str(value))
for single_value in value_list:
try:
int(single_value)
except (ValueError, TypeError):
raise ValueError(
"The " + name + "'s must be a list or tuple of " + str(
n) + " integers. Received: " + str(value) + " "
"including element " + str(single_value) + " of type" + " "
+ str(type(single_value)))
return value_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册