提交 5f48421c 编写于 作者: C chengduoZH

fix conv2d_transpose API (Add dilation)

上级 d93bbf1b
...@@ -1537,6 +1537,7 @@ def conv2d_transpose(input, ...@@ -1537,6 +1537,7 @@ def conv2d_transpose(input,
filter_size=None, filter_size=None,
padding=None, padding=None,
stride=None, stride=None,
dilation=None,
param_attr=None, param_attr=None,
main_program=None, main_program=None,
startup_program=None): startup_program=None):
...@@ -1562,6 +1563,9 @@ def conv2d_transpose(input, ...@@ -1562,6 +1563,9 @@ def conv2d_transpose(input,
stride(int|tuple): The stride size. If stride is a tuple, it must stride(int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. stride_H = stride_W = stride.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation.
param_attr: Parameter Attribute. param_attr: Parameter Attribute.
main_program(Program): the main program main_program(Program): the main program
startup_program(Program): the startup program startup_program(Program): the startup program
...@@ -1586,6 +1590,11 @@ def conv2d_transpose(input, ...@@ -1586,6 +1590,11 @@ def conv2d_transpose(input,
elif stride is not None: elif stride is not None:
op_attr['strides'] = stride op_attr['strides'] = stride
if isinstance(dilation, int):
op_attr['dilations'] = dilation
elif stride is not None:
op_attr['dilations'] = dilation
if filter_size is None: if filter_size is None:
if output_size is None: if output_size is None:
raise ValueError("output_size must be set when filter_size is None") raise ValueError("output_size must be set when filter_size is None")
...@@ -1594,11 +1603,14 @@ def conv2d_transpose(input, ...@@ -1594,11 +1603,14 @@ def conv2d_transpose(input,
padding = op_attr.get('paddings', [0, 0]) padding = op_attr.get('paddings', [0, 0])
stride = op_attr.get('strides', [1, 1]) stride = op_attr.get('strides', [1, 1])
dilation = op_attr.get('dilations', [1, 1])
h_in = input.shape[2] h_in = input.shape[2]
w_in = input.shape[3] w_in = input.shape[3]
filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0] filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 *
filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] padding[0] - 1) / dilation[0] + 1
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
padding[1] - 1) / dilation[1] + 1
filter_size = [filter_size_h, filter_size_w] filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int): elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size] filter_size = [filter_size, filter_size]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册