diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc index 901682edbb01c563be6ea407228336b14f942778..038ea8999072f562104c5386ed18b6b275816345 100644 --- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc @@ -44,6 +44,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { std::vector paddings = ctx.Attr>("paddings"); // cudnn v5 does not support dilations std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); int user_workspace_size = ctx.Attr("workspace_size_MB"); const T* input_data = input->data(); @@ -64,13 +65,13 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { // (N, M, H, W) or (N, M, D, H, W) cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, framework::vectorize2int(input->dims())); + layout, framework::vectorize2int(input->dims()), groups); // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - layout, framework::vectorize2int(output->dims())); + layout, framework::vectorize2int(output->dims()), groups); // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w) cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( - layout, framework::vectorize2int(filter->dims())); + layout, framework::vectorize2int(filter->dims()), groups); cudnnConvolutionDescriptor_t cudnn_conv_desc = conv_desc.descriptor(paddings, strides, dilations); @@ -104,11 +105,17 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel { cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv transpose forward --------------------- + int input_offset = input->numel() / input->dims()[0] / groups; + int output_offset = output->numel() / output->dims()[0] / groups; + int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; - PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( - handle, &alpha, cudnn_filter_desc, filter_data, cudnn_input_desc, - input_data, cudnn_conv_desc, algo, cudnn_workspace, - workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); + for (int g = 0; g < groups; g++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( + handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, + cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, + algo, cudnn_workspace, workspace_size_in_bytes, &beta, + cudnn_output_desc, output_data + output_offset * g)); + } // Release the cudnn workspace paddle::memory::Free(gpu, cudnn_workspace); @@ -134,6 +141,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { std::vector paddings = ctx.Attr>("paddings"); // cudnn v5 does not support dilations std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); int user_workspace_size = ctx.Attr("workspace_size_MB"); // ------------------- cudnn descriptors --------------------- @@ -145,13 +153,13 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { // Input: (N, M, H, W) or (N, M, D, H, W) cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, framework::vectorize2int(input->dims())); + layout, framework::vectorize2int(input->dims()), groups); // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w) cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - layout, framework::vectorize2int(output_grad->dims())); + layout, framework::vectorize2int(output_grad->dims()), groups); // Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w) cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( - layout, framework::vectorize2int(filter->dims())); + layout, framework::vectorize2int(filter->dims()), groups); cudnnConvolutionDescriptor_t cudnn_conv_desc = conv_desc.descriptor(paddings, strides, dilations); @@ -205,15 +213,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- // FIXME(typhoonzero): template type T may not be the same as cudnn call. + int input_offset = input->numel() / input->dims()[0] / groups; + int output_grad_offset = + output_grad->numel() / output_grad->dims()[0] / groups; + int filter_offset = filter->numel() / groups; T alpha = 1.0f, beta = 0.0f; if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. - PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( - handle, &alpha, cudnn_output_desc, output_grad_data, - cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo, - cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, - input_grad_data)); + for (int g = 0; g < groups; g++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_filter_desc, + filter_data + filter_offset * g, cudnn_conv_desc, data_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, + input_grad_data + input_offset * g)); + } } // ------------------- cudnn conv backward filter --------------------- @@ -221,11 +236,16 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset filter_grad. // Gradient with respect to the filter - PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( - handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc, - input_data, cudnn_conv_desc, filter_algo, cudnn_workspace, - workspace_size_in_bytes, &beta, cudnn_filter_desc, filter_grad_data)); + for (int g = 0; g < groups; g++) { + PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( + handle, &alpha, cudnn_output_desc, + output_grad_data + output_grad_offset * g, cudnn_input_desc, + input_data + input_offset * g, cudnn_conv_desc, filter_algo, + cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc, + filter_grad_data + filter_offset * g)); + } } + // Release the cudnn workspace paddle::memory::Free(gpu, cudnn_workspace); } diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index c27c8e273168407d3aacb05cd6628887cc5760ad..0b363f5c43f9fc191790e5cca629ffc46eb9388c 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -32,6 +32,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); std::vector dilations = ctx->Attrs().Get>("dilations"); + int groups = ctx->Attrs().Get("groups"); PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, "ConvTransposeOp intput should be 4-D or 5-D tensor."); @@ -48,10 +49,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { "ConvTransposeOp paddings dimension and dilations " "dimension should be the same."); PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], - "In ConvTransposeOp, The input channel should be the same " - "as the number of filters."); + "In ConvTransposeOp, The number of input channels should " + "be equal to the number of filter's channels."); - std::vector output_shape({in_dims[0], filter_dims[1]}); + std::vector output_shape({in_dims[0], filter_dims[1] * groups}); for (size_t i = 0; i < strides.size(); ++i) { auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] + @@ -102,7 +103,10 @@ void Conv2DTransposeOpMaker::Make() { AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator. " "The format of output tensor is also NCHW."); - + AddAttr("groups", + "(int default:1), the groups number of the convolution " + "transpose operator. ") + .SetDefault(1); AddAttr>("dilations", "(vector default:{1, 1}), the " "dilations(h_dilation, w_dilation) of convolution " @@ -204,6 +208,10 @@ void Conv3DTransposeOpMaker::Make() { "(vector default:{0, 0, 0}), paddings(d_pad, " "h_pad, w_pad) of convolution transpose operator.") .SetDefault({0, 0, 0}); + AddAttr("groups", + "(int default:1), the groups number of the convolution3d " + "transpose operator. ") + .SetDefault(1); AddAttr( "use_cudnn", "(bool, default false) Only used in cudnn kernel, need install cudnn") diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index f9d205a5b5c4cff74d02a6c89b83f7584e4a6824..1dcfc651fdd79aed50736d05d38ec8576b183d41 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -70,7 +70,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); - // groups will alway be disabled in conv2dtranspose. + int groups = context.Attr("groups"); const int batch_size = static_cast(input->dims()[0]); @@ -81,10 +81,10 @@ class GemmConvTransposeKernel : public framework::OpKernel { // use col_shape in the im2col and col2im (or vol2col and col2vol) // calculation - // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w} + // col_shape_vec: {c/g, k_h, k_w, h, w} or {c/g, k_d, k_h, k_w, d, h, w} size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = output->dims()[1]; + col_shape_vec[0] = output->dims()[1] / groups; for (size_t j = 0; j < data_dim; ++j) { col_shape_vec[j + 1] = filter_shape_vec[j + 2]; col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2]; @@ -92,7 +92,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { DDim col_shape(framework::make_ddim(col_shape_vec)); // use col_matrix_shape in the gemm calculation - // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) + // size: (c/g * k_h * k_w, h * w) or (c/g * k_d * k_h * k_w, d * h * w) DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1); Tensor col; @@ -111,7 +111,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // input matrix size: (m, h * w) or (m, d * h * w) DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]}; - // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w) + // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w) DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]}; filter.Resize(filter_matrix_shape); @@ -121,6 +121,8 @@ class GemmConvTransposeKernel : public framework::OpKernel { auto blas = math::GetBlas(dev_ctx); set_zero(dev_ctx, output, static_cast(0)); + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; math::Col2ImFunctor col2im; math::Col2VolFunctor col2vol; @@ -133,22 +135,29 @@ class GemmConvTransposeKernel : public framework::OpKernel { // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w) Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); - // col_matrix = filter * input_batch - // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w) - blas.MatMul(filter, true, input_batch, false, static_cast(1.0), - &col_matrix, static_cast(0.0)); - - if (data_dim == 2U) { - // col2im: col_matrix -> dy - // from (c * k_h * k_w, h * w) to (c, o_h, o_w) - col2im(dev_ctx, col, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &output_batch); - } else if (data_dim == 3U) { - // col2vol: col_matrix -> dy - // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) - col2vol(dev_ctx, col, dilations, strides, paddings, &output_batch); + for (int g = 0; g < groups; g++) { + Tensor in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step); + Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); + Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step); + + // col_matrix = filter_slice * input_slice + // of shape (c/g * k_h * k_w, h * w) + // or (c/g * k_d * k_h * k_w, d * h * w) + blas.MatMul(filter_slice, true, in_slice, false, static_cast(1.0), + &col_matrix, static_cast(0.0)); + + if (data_dim == 2U) { + // col2im: col_matrix -> dy + // from (c/g * k_h * k_w, h * w) to (c/g, o_h, o_w) + col2im(dev_ctx, col, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &out_slice); + } else if (data_dim == 3U) { + // col2vol: col_matrix -> dy + // from (c/g * k_d * k_h * k_w, d * h * w) to (c/g, o_d, o_h, o_w) + col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice); + } } } } @@ -174,6 +183,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + int groups = context.Attr("groups"); const int batch_size = static_cast(input->dims()[0]); @@ -205,9 +215,11 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // input matrix size: (m, h * w) or (m, d * h * w) DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]}; - // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w) - DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]}; + // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w) + DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0] / groups}; filter.Resize(filter_matrix_shape); + int in_step = static_cast(input->dims()[1]) / groups; + int col_step = static_cast(col_matrix_shape[0]) / groups; // convolution transpose grad on input: // im2col + gemm (similar to conv-forward) @@ -233,7 +245,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (input_grad) { input_grad->mutable_data(context.GetPlace()); } - if (filter_grad) { // filter size (m, c, k_h, k_w) + if (filter_grad) { // filter size (m, c/g, k_h, k_w) filter_grad->mutable_data(context.GetPlace()); set_zero(dev_ctx, filter_grad, static_cast(0)); filter_grad_ = *filter_grad; @@ -268,8 +280,17 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // or // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m, // d, h, w) - blas.MatMul(filter, false, col_matrix, false, static_cast(1.0), - &input_grad_batch, static_cast(0.0)); + for (int g = 0; g < groups; g++) { + Tensor input_grad_slice = + input_grad_batch.Slice(g * in_step, (g + 1) * in_step); + Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); + Tensor col_matrix_slice = + col_matrix.Slice(g * col_step, (g + 1) * col_step); + + blas.MatMul(filter_slice, false, col_matrix_slice, false, + static_cast(1.0), &input_grad_slice, + static_cast(0.0)); + } } if (filter_grad) { // input batch @@ -279,8 +300,17 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // or // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d * // k_h * k_w) - blas.MatMul(in_batch, false, col_matrix, true, static_cast(1.0), - &filter_grad_, static_cast(1.0)); + for (int g = 0; g < groups; g++) { + Tensor in_batch_slice = + in_batch.Slice(g * in_step, (g + 1) * in_step); + Tensor filter_grad_slice = + filter_grad_.Slice(g * in_step, (g + 1) * in_step); + Tensor col_matrix_slice = + col_matrix.Slice(g * col_step, (g + 1) * col_step); + blas.MatMul(in_batch_slice, false, col_matrix_slice, true, + static_cast(1.0), &filter_grad_slice, + static_cast(1.0)); + } } } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1f2e483a0968308063710d3081fe0ddc7b559d75..dd360c2b98414d1cf2c0da5f7c8d5c6ca461a22a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1708,6 +1708,7 @@ def conv2d_transpose(input, padding=0, stride=1, dilation=1, + groups=None, param_attr=None, bias_attr=None, use_cudnn=True, @@ -1778,6 +1779,12 @@ def conv2d_transpose(input, dilation(int|tuple): The dilation size. If dilation is a tuple, it must contain two integers, (dilation_H, dilation_W). Otherwise, the dilation_H = dilation_W = dilation. Default: dilation = 1. + groups(int): The groups number of the Conv2d transpose layer. Inspired by + grouped convolution in Alex Krizhevsky's Deep CNN paper, in which + when group=2, the first half of the filters is only connected to the + first half of the input channels, while the second half of the + filters is only connected to the second half of the input channels. + Default: groups=1 param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer. Default: None bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None @@ -1832,7 +1839,8 @@ def conv2d_transpose(input, filter_size = utils.convert_to_list(filter_size, 2, 'conv2d_transpose.filter_size') - filter_shape = [input_channel, num_filters] + filter_size + groups = 1 if groups is None else groups + filter_shape = [input_channel, num_filters / groups] + filter_size img_filter = helper.create_parameter( dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index d864b9b348e961c585749d47d449d775b2dfebc9..ded2f130288a4a959a1c859b2cc8ccf0912efb12 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -21,8 +21,11 @@ from op_test import OpTest def conv2dtranspose_forward_naive(input_, filter_, attrs): in_n, in_c, in_h, in_w = input_.shape - f_c, out_c, f_h, f_w = filter_.shape + f_c, f_out_c, f_h, f_w = filter_.shape + groups = attrs['groups'] assert in_c == f_c + out_c = f_out_c * groups + sub_in_c = in_c / groups stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ 'dilations'] @@ -36,15 +39,21 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs): for n in range(in_n): for i in range(in_h): for j in range(in_w): - input_masked = input_[n, :, i, j] # (c) - input_masked = np.reshape(input_masked, (in_c, 1, 1)) - input_masked = np.tile(input_masked, (1, f_h, f_w)) - - for k in range(out_c): - tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0) - i1, i2 = i * stride[0], i * stride[0] + d_bolck_h - j1, j2 = j * stride[0], j * stride[0] + d_bolck_h - out[n, k, i1:i2:dilations[0], j1:j2:dilations[1]] += tmp_out + for g in range(groups): + input_masked = input_[n, g * sub_in_c:(g + 1) * sub_in_c, i, + j] # (c) + input_masked = np.reshape(input_masked, (sub_in_c, 1, 1)) + input_masked = np.tile(input_masked, (1, f_h, f_w)) + + for k in range(f_out_c): + tmp_out = np.sum( + input_masked * + filter_[g * sub_in_c:(g + 1) * sub_in_c, k, :, :], + axis=0) + i1, i2 = i * stride[0], i * stride[0] + d_bolck_h + j1, j2 = j * stride[0], j * stride[0] + d_bolck_h + out[n, g * f_out_c + k, i1:i2:dilations[0], j1:j2: + dilations[1]] += tmp_out out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]] return out @@ -64,6 +73,7 @@ class TestConv2dTransposeOp(OpTest): self.attrs = { 'strides': self.stride, 'paddings': self.pad, + 'groups': self.groups, 'dilations': self.dilations, 'use_cudnn': self.use_cudnn, 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter @@ -127,6 +137,7 @@ class TestConv2dTransposeOp(OpTest): self.pad = [0, 0] self.stride = [1, 1] self.dilations = [1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5] # NCHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3] @@ -140,16 +151,29 @@ class TestWithPad(TestConv2dTransposeOp): self.pad = [1, 1] self.stride = [1, 1] self.dilations = [1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5] # NCHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3] +class TestWithGroups(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3] + + class TestWithStride(TestConv2dTransposeOp): def init_test_case(self): self.pad = [1, 1] self.stride = [2, 2] self.dilations = [1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5] # NCHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3] @@ -159,6 +183,7 @@ class TestWithDilation(TestConv2dTransposeOp): def init_test_case(self): self.pad = [1, 1] self.stride = [1, 1] + self.groups = 1 self.dilations = [2, 2] self.input_size = [2, 3, 5, 5] # NCHW f_c = self.input_size[1] @@ -176,6 +201,7 @@ class TestCUDNNWithPad(TestWithPad): def init_test_case(self): self.pad = [1, 1] self.stride = [1, 1] + self.groups = 1 self.dilations = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW f_c = self.input_size[1] @@ -190,6 +216,7 @@ class TestCUDNNWithStride(TestWithStride): def init_test_case(self): self.pad = [1, 1] self.stride = [2, 2] + self.groups = 1 self.dilations = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW f_c = self.input_size[1] @@ -200,6 +227,21 @@ class TestCUDNNWithStride(TestWithStride): self.op_type = "conv2d_transpose" +class TestCUDNNWithGroups(TestWithGroups): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3] + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d_transpose" + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation): diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py index 55ba238710c56dd0daea388cd2dcdb79243bb71e..c9f26d10df8ff39d6bd77b1597336600f676d362 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_op.py @@ -21,8 +21,11 @@ from op_test import OpTest def conv3dtranspose_forward_naive(input_, filter_, attrs): in_n, in_c, in_d, in_h, in_w = input_.shape - f_c, out_c, f_d, f_h, f_w = filter_.shape + f_c, f_out_c, f_d, f_h, f_w = filter_.shape + groups = attrs['groups'] assert in_c == f_c + out_c = f_out_c * groups + sub_in_c = in_c / groups stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[ 'dilations'] @@ -39,18 +42,23 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs): for d in range(in_d): for i in range(in_h): for j in range(in_w): - input_masked = input_[n, :, d, i, j] # (c) - input_masked = np.reshape(input_masked, (in_c, 1, 1, 1)) - input_masked = np.tile(input_masked, (1, f_d, f_h, f_w)) - - for k in range(out_c): - tmp_out = np.sum(input_masked * filter_[:, k, :, :, :], - axis=0) - d1, d2 = d * stride[0], d * stride[0] + d_bolck_d - i1, i2 = i * stride[1], i * stride[1] + d_bolck_h - j1, j2 = j * stride[2], j * stride[2] + d_bolck_w - out[n, k, d1:d2:dilations[0], i1:i2:dilations[1], j1:j2: - dilations[2]] += tmp_out + for g in range(groups): + input_masked = input_[n, g * sub_in_c:(g + 1 + ) * sub_in_c, d, + i, j] # (c) + input_masked = np.reshape(input_masked, + (sub_in_c, 1, 1, 1)) + input_masked = np.tile(input_masked, (1, f_d, f_h, f_w)) + + for k in range(f_out_c): + tmp_out = np.sum(input_masked * filter_[ + g * sub_in_c:(g + 1) * sub_in_c, k, :, :, :], + axis=0) + d1, d2 = d * stride[0], d * stride[0] + d_bolck_d + i1, i2 = i * stride[1], i * stride[1] + d_bolck_h + j1, j2 = j * stride[2], j * stride[2] + d_bolck_w + out[n, g * f_out_c + k, d1:d2:dilations[0], i1:i2: + dilations[1], j1:j2:dilations[2]] += tmp_out out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w - pad[2]] @@ -72,6 +80,7 @@ class TestConv3dTransposeOp(OpTest): 'strides': self.stride, 'paddings': self.pad, 'dilations': self.dilations, + 'groups': self.groups, 'use_cudnn': self.use_cudnn, 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter } @@ -134,6 +143,7 @@ class TestConv3dTransposeOp(OpTest): self.pad = [0, 0, 0] self.stride = [1, 1, 1] self.dilations = [1, 1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5, 5] # NCDHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3, 3] @@ -147,16 +157,29 @@ class TestWithPad(TestConv3dTransposeOp): self.pad = [1, 1, 1] self.stride = [1, 1, 1] self.dilations = [1, 1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5, 5] # NCDHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3, 3] +class TestWithGroups(TestConv3dTransposeOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3, 3] + + class TestWithStride(TestConv3dTransposeOp): def init_test_case(self): self.pad = [1, 1, 1] self.stride = [2, 2, 2] self.dilations = [1, 1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5, 5] # NCDHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3, 3] @@ -167,6 +190,7 @@ class TestWithDilation(TestConv3dTransposeOp): self.pad = [1, 1, 1] self.stride = [1, 1, 1] self.dilations = [2, 2, 2] + self.groups = 1 self.input_size = [2, 3, 5, 5, 5] # NCDHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3, 3] @@ -184,6 +208,7 @@ class TestCUDNNWithPad(TestWithPad): self.pad = [1, 1, 1] self.stride = [1, 1, 1] self.dilations = [1, 1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5, 5] # NCDHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3, 3] @@ -198,6 +223,7 @@ class TestCUDNNWithStride(TestWithStride): self.pad = [1, 1, 1] self.stride = [2, 2, 2] self.dilations = [1, 1, 1] + self.groups = 1 self.input_size = [2, 3, 5, 5, 5] # NCDHW f_c = self.input_size[1] self.filter_size = [f_c, 6, 3, 3, 3] @@ -207,6 +233,21 @@ class TestCUDNNWithStride(TestWithStride): self.op_type = "conv3d_transpose" +class TestCUDNNWithGroups(TestWithGroups): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.dilations = [1, 1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3, 3] + + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv3d_transpose" + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation):