From 5f48421cc3718f3af2c8b90cf206089f1702592d Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 8 Dec 2017 20:03:31 +0800 Subject: [PATCH] fix conv2d_transpose API (Add dilation) --- python/paddle/v2/fluid/layers.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index 99d0ac4a1..7c1514efa 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -1537,6 +1537,7 @@ def conv2d_transpose(input, filter_size=None, padding=None, stride=None, + dilation=None, param_attr=None, main_program=None, startup_program=None): @@ -1562,6 +1563,9 @@ def conv2d_transpose(input, stride(int|tuple): The stride size. If stride is a tuple, it must contain two integers, (stride_H, stride_W). Otherwise, the stride_H = stride_W = stride. + dilation(int|tuple): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. param_attr: Parameter Attribute. main_program(Program): the main program startup_program(Program): the startup program @@ -1586,6 +1590,11 @@ def conv2d_transpose(input, elif stride is not None: op_attr['strides'] = stride + if isinstance(dilation, int): + op_attr['dilations'] = dilation + elif stride is not None: + op_attr['dilations'] = dilation + if filter_size is None: if output_size is None: raise ValueError("output_size must be set when filter_size is None") @@ -1594,11 +1603,14 @@ def conv2d_transpose(input, padding = op_attr.get('paddings', [0, 0]) stride = op_attr.get('strides', [1, 1]) + dilation = op_attr.get('dilations', [1, 1]) h_in = input.shape[2] w_in = input.shape[3] - filter_size_h = output_size[0] - (h_in - 1) * stride[0] + 2 * padding[0] - filter_size_w = output_size[1] - (w_in - 1) * stride[1] + 2 * padding[1] + filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 * + padding[0] - 1) / dilation[0] + 1 + filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * + padding[1] - 1) / dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] elif isinstance(filter_size, int): filter_size = [filter_size, filter_size] -- GitLab