diff --git a/paddle/operators/conv_transpose_cudnn_op.cc b/paddle/operators/conv_transpose_cudnn_op.cc index 2348bed4ff1c658e2b5614a3bdd94dfb554ad705..8980ff91f5d7c585d9ce0ce62cfb90f47ea86ec6 100644 --- a/paddle/operators/conv_transpose_cudnn_op.cc +++ b/paddle/operators/conv_transpose_cudnn_op.cc @@ -21,8 +21,6 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker { public: CudnnConv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker) : Conv2DTransposeOpMaker(proto, op_checker) { - AddAttr>("dilations", "dilations of convolution operator.") - .SetDefault({1, 1}); AddAttr("workspace_size_MB", "workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " @@ -37,8 +35,6 @@ class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker { public: CudnnConv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker) : Conv3DTransposeOpMaker(proto, op_checker) { - AddAttr>("dilations", "dilations of convolution operator.") - .SetDefault({1, 1, 1}); AddAttr("workspace_size_MB", "workspace size for cudnn, in MB, " "workspace is a section of GPU memory which will be " diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index cae0e2ca2b472818463e109cb700da7f62180926..5e24fc4b2c8b479caac417c957033f7552e1c3f0 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { auto filter_dims = ctx->GetInputDim("Filter"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); + std::vector dilations = ctx->Attrs().Get>("dilations"); PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, "ConvTransposeOp intput should be 4-D or 5-D tensor."); @@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), "ConvTransposeOp paddings dimension and strides " "dimension should be the same."); + PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(), + "ConvTransposeOp paddings dimension and dilations " + "dimension should be the same."); PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], "In ConvTransposeOp, The input channel should be the same " "as the number of filters."); std::vector output_shape({in_dims[0], filter_dims[1]}); for (size_t i = 0; i < strides.size(); ++i) { + auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + - filter_dims[i + 2]); + filter_extent); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -73,6 +78,12 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto, AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator. " "The format of output tensor is also NCHW."); + + AddAttr>("dilations", + "(vector default:{1, 1}), the " + "dilations(h_dilation, w_dilation) of convolution " + "transpose operator.") + .SetDefault({1, 1}); AddAttr>( "strides", "(vector default:{1, 1}), the strides(h_stride, w_stride) of " @@ -87,7 +98,7 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto, Convolution2D Transpose Operator. The convolution transpose operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the +and dilations, strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the number of channels, H is the height of the feature, and W is the width of the feature. @@ -136,6 +147,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto, "Where N is batch size, C is " "the number of channels, D is the depth of the feature, H is the " "height of the feature, and W is the width of the feature."); + + AddAttr>( + "dilations", + "(vector default:{1, 1, 1}), the " + "dilations(d_dilation,h_dilation, w_dilation) of convolution " + "transpose operator.") + .SetDefault({1, 1, 1}); AddAttr>("strides", "(vector default:{1, 1, 1}), the " "strides{d_stride, h_stride, w_stride} of " @@ -149,7 +167,7 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto, Convolution3D Transpose Operator. The convolution transpose operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the +and dilations, strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the number of channels, D is the depth of the feature, H is the height of the feature, diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index e81651f417a5b96a1f25ae9ca9228d332a16bc6d..4c8f8a80672788e8b2919e500d3627adec1ad035 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -61,6 +61,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); // groups will alway be disabled in conv2dtranspose. const int batch_size = static_cast(input->dims()[0]); @@ -113,7 +114,6 @@ class GemmConvTransposeKernel : public framework::OpKernel { math::Col2ImFunctor col2im; math::Col2VolFunctor col2vol; - std::vector dilations({1, 1, 1}); // convolution transpose: gemm + col2im or col2vol (similar to conv-backward // on input) @@ -165,6 +165,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -219,7 +220,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { math::Im2ColFunctor im2col; math::Vol2ColFunctor vol2col; - std::vector dilations({1, 1, 1}); if (input_grad) { input_grad->mutable_data(context.GetPlace()); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 8d819de603a81de6934dd4b03a6fc49f4319a2d3..1db63fbfe806e8e99be07980c56fcc32018ad156 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -704,6 +704,7 @@ def conv2d_transpose(input, filter_size=None, padding=None, stride=None, + dilation=None, param_attr=None): """ The transpose of conv2d layer. @@ -727,6 +728,9 @@ def conv2d_transpose(input, stride(int|tuple): The stride size. If stride is a tuple, it must contain two integers, (stride_H, stride_W). Otherwise, the stride_H = stride_W = stride. + dilation(int|tuple): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. param_attr: Parameter Attribute. main_program(Program): the main program startup_program(Program): the startup program @@ -747,10 +751,15 @@ def conv2d_transpose(input, op_attr['paddings'] = padding if isinstance(stride, int): - op_attr['strides'] = stride + op_attr['strides'] = [stride, stride] elif stride is not None: op_attr['strides'] = stride + if isinstance(dilation, int): + op_attr['dilations'] = [dilation, dilation] + elif dilation is not None: + op_attr['dilations'] = dilation + if filter_size is None: if output_size is None: raise ValueError("output_size must be set when filter_size is None") @@ -759,14 +768,17 @@ def conv2d_transpose(input, padding = op_attr.get('paddings', [0, 0]) stride = op_attr.get('strides', [1, 1]) + dilation = op_attr.get('dilations', [1, 1]) h_in = input.shape[2] w_in = input.shape[3] - filter_size_h = output_size[0] - \ - (h_in - 1) * stride[0] + 2 * padding[0] - filter_size_w = output_size[1] - \ - (w_in - 1) * stride[1] + 2 * padding[1] + + filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 * + padding[0] - 1) / dilation[0] + 1 + filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 * + padding[1] - 1) / dilation[1] + 1 filter_size = [filter_size_h, filter_size_w] + elif isinstance(filter_size, int): filter_size = [filter_size, filter_size] diff --git a/python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py b/python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py index d7b1f2f2a3abf6335998742dbbef8e17794170fa..d59537b924d57d40f7d740d99eb814c95f528e5f 100644 --- a/python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py +++ b/python/paddle/v2/fluid/tests/test_conv2d_transpose_op.py @@ -3,14 +3,17 @@ import numpy as np from op_test import OpTest -def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param): +def conv2dtranspose_forward_naive(input_, filter_, attrs): in_n, in_c, in_h, in_w = input_.shape f_c, out_c, f_h, f_w = filter_.shape assert in_c == f_c - stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad'] - out_h = (in_h - 1) * stride[0] + f_h - out_w = (in_w - 1) * stride[1] + f_w + stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ + 'dilations'] + d_bolck_h = dilations[0] * (f_h - 1) + 1 + d_bolck_w = dilations[1] * (f_w - 1) + 1 + out_h = (in_h - 1) * stride[0] + d_bolck_h + out_w = (in_w - 1) * stride[1] + d_bolck_w out = np.zeros((in_n, out_c, out_h, out_w)) @@ -23,9 +26,9 @@ def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param): for k in range(out_c): tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0) - i1, i2 = i * stride[0], i * stride[0] + f_h - j1, j2 = j * stride[0], j * stride[0] + f_w - out[n, k, i1:i2, j1:j2] += tmp_out + i1, i2 = i * stride[0], i * stride[0] + d_bolck_h + j1, j2 = j * stride[0], j * stride[0] + d_bolck_h + out[n, k, i1:i2:dilations[0], j1:j2:dilations[1]] += tmp_out out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]] return out @@ -37,11 +40,8 @@ class TestConv2dTransposeOp(OpTest): self.init_op_type() self.init_test_case() - conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad} input_ = np.random.random(self.input_size).astype("float32") filter_ = np.random.random(self.filter_size).astype("float32") - output = conv2dtranspose_forward_naive( - input_, filter_, conv2dtranspose_param).astype('float32') self.inputs = {'Input': input_, 'Filter': filter_} self.attrs = { @@ -49,6 +49,10 @@ class TestConv2dTransposeOp(OpTest): 'paddings': self.pad, 'dilations': self.dilations } + + output = conv2dtranspose_forward_naive(input_, filter_, + self.attrs).astype('float32') + self.outputs = {'Output': output} def test_check_output(self): @@ -104,11 +108,60 @@ class TestWithStride(TestConv2dTransposeOp): self.filter_size = [f_c, 6, 3, 3] +class TestWithDilation(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + # ------------ test_cudnn ------------ class TestCudnn(TestConv2dTransposeOp): def init_op_type(self): self.op_type = "conv2d_transpose_cudnn" +class TestCudnnWithPad(TestWithPad): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "conv2d_transpose_cudnn" + + +class TestCudnnWithStride(TestWithStride): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "conv2d_transpose_cudnn" + + +# #cudnn v5 does not support dilation conv. +# class TestCudnnWithDilation(TestWithDilation): +# def init_test_case(self): +# self.pad = [1, 1] +# self.stride = [2, 2] +# self.dilations = [2, 2] +# self.input_size = [2, 3, 5, 5] # NCHW +# f_c = self.input_size[1] +# self.filter_size = [f_c, 6, 3, 3] +# +# def init_op_type(self): +# self.op_type = "conv2d_transpose_cudnn" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py b/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py index 8fd34b87bfea91307f52fdcbb9f71f2e1a9c6c56..a353f9b4d40233de46237005138f21430f4d865a 100644 --- a/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py +++ b/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py @@ -3,15 +3,20 @@ import numpy as np from op_test import OpTest -def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param): +def conv3dtranspose_forward_naive(input_, filter_, attrs): in_n, in_c, in_d, in_h, in_w = input_.shape f_c, out_c, f_d, f_h, f_w = filter_.shape assert in_c == f_c - stride, pad = conv3dtranspose_param['stride'], conv3dtranspose_param['pad'] - out_d = (in_d - 1) * stride[0] + f_d - out_h = (in_h - 1) * stride[1] + f_h - out_w = (in_w - 1) * stride[2] + f_w + stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ + 'dilations'] + + d_bolck_d = dilations[0] * (f_d - 1) + 1 + d_bolck_h = dilations[1] * (f_h - 1) + 1 + d_bolck_w = dilations[2] * (f_w - 1) + 1 + out_d = (in_d - 1) * stride[0] + d_bolck_d + out_h = (in_h - 1) * stride[1] + d_bolck_h + out_w = (in_w - 1) * stride[2] + d_bolck_w out = np.zeros((in_n, out_c, out_d, out_h, out_w)) for n in range(in_n): @@ -25,10 +30,11 @@ def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param): for k in range(out_c): tmp_out = np.sum(input_masked * filter_[:, k, :, :, :], axis=0) - d1, d2 = d * stride[0], d * stride[0] + f_d - i1, i2 = i * stride[1], i * stride[1] + f_h - j1, j2 = j * stride[2], j * stride[2] + f_w - out[n, k, d1:d2, i1:i2, j1:j2] += tmp_out + d1, d2 = d * stride[0], d * stride[0] + d_bolck_d + i1, i2 = i * stride[1], i * stride[1] + d_bolck_h + j1, j2 = j * stride[2], j * stride[2] + d_bolck_w + out[n, k, d1:d2:dilations[0], i1:i2:dilations[1], j1:j2: + dilations[2]] += tmp_out out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w - pad[2]] @@ -41,18 +47,19 @@ class TestConv3dTransposeOp(OpTest): self.init_op_type() self.init_test_case() - conv3dtranspose_param = {'stride': self.stride, 'pad': self.pad} input_ = np.random.random(self.input_size).astype("float32") filter_ = np.random.random(self.filter_size).astype("float32") - output = conv3dtranspose_forward_naive( - input_, filter_, conv3dtranspose_param).astype("float32") self.inputs = {'Input': input_, 'Filter': filter_} self.attrs = { 'strides': self.stride, 'paddings': self.pad, - # 'dilations': self.dilations + 'dilations': self.dilations } + + output = conv3dtranspose_forward_naive(input_, filter_, + self.attrs).astype("float32") + self.outputs = {'Output': output} def test_check_output(self): @@ -108,11 +115,60 @@ class TestWithStride(TestConv3dTransposeOp): self.filter_size = [f_c, 6, 3, 3, 3] +class TestWithDilation(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [2, 2, 2] + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + + # ------------ test_cudnn ------------ class TestCudnn(TestConv3dTransposeOp): def init_op_type(self): self.op_type = "conv3d_transpose_cudnn" +class TestCudnnWithPad(TestWithPad): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + + def init_op_type(self): + self.op_type = "conv3d_transpose_cudnn" + + +class TestCudnnWithStride(TestWithStride): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [2, 2, 2] + self.dilations = [1, 1, 1] + self.input_size = [2, 3, 5, 5, 5] # NCDHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3, 3] + + def init_op_type(self): + self.op_type = "conv3d_transpose_cudnn" + + +# #cudnn v5 does not support dilation conv. +# class TestCudnnWithDilation(TestWithDilation): +# def init_test_case(self): +# self.pad = [1, 1, 1] +# self.stride = [2, 2, 2] +# self.dilations = [2, 2, 2] +# self.input_size = [2, 3, 5, 5, 5] # NCDHW +# f_c = self.input_size[1] +# self.filter_size = [f_c, 6, 3, 3, 3] +# +# def init_op_type(self): +# self.op_type = "conv3d_transpose_cudnn" + if __name__ == '__main__': unittest.main()