From fb46345f007e7c989d8c5d635dc0ff9d24bbbf31 Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Wed, 13 Sep 2017 14:15:58 +0800
Subject: [PATCH] Add groups in convolution operator.

---
 paddle/operators/conv_op.cc     | 22 ++++++++++++++++++--
 paddle/operators/gemm_conv_op.h | 36 ++++++++++++++++++++++-----------
 2 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index 107682848..174f777f0 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -31,12 +31,22 @@ class Conv2DOp : public framework::OperatorWithKernel {
     auto in = ctx.Input<Tensor>("Input");
     auto filter = ctx.Input<Tensor>("Filter");
     auto out = ctx.Output<Tensor>("Output");
+    std::vector<int> strides = Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
+    int groups = Attr<int>("groups");
+    int input_channels = in->dims()[1];
+    int output_channels = filter->dims()[0];
+
     PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp intput should be 4-D.");
     PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
                       "Conv2DOp filter should be 4-D.");
+    PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
+                      "The number of input channels should be equal to filter "
+                      "channels * groups.");
+    PADDLE_ENFORCE_EQ(
+        output_channels % groups, 0,
+        "The number of output channels should be divisible by groups.");
 
-    std::vector<int> strides = Attr<std::vector<int>>("strides");
-    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
     auto output_height =
         outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
     auto output_width =
@@ -71,6 +81,14 @@ the input, filter and strides, paddings parameters.
 )DOC");
     AddAttr<std::vector<int>>("strides", "strides of convolution operator.");
     AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.");
+    AddAttr<int>(
+        "groups",
+        "group size of convolution operator. "
+        "Refer to grouped convolution in Alex Krizhevsky's paper: "
+        "when group=2, the first half of the filters is only connected to the "
+        "first half of the input channels, and the second half is only "
+        "connected to the second half of the input channels.")
+        .SetDefault(1);
   }
 };
 
diff --git a/paddle/operators/gemm_conv_op.h b/paddle/operators/gemm_conv_op.h
index 3b7ba685c..8ac92d3bd 100644
--- a/paddle/operators/gemm_conv_op.h
+++ b/paddle/operators/gemm_conv_op.h
@@ -38,6 +38,7 @@ class GemmConvKernel : public framework::OpKernel {
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    int groups = context.Attr<int>("groups");
 
     int batch_size = input->dims()[0];
     int input_channels = input->dims()[1];
@@ -51,11 +52,11 @@ class GemmConvKernel : public framework::OpKernel {
         paddle::operators::math::ColFormat::kCFO, Place, T>
         im2col;
     // use col_shape in the im2col calculation
-    framework::DDim col_shape = {input_channels, filter_height, filter_width,
-                                 output_height, output_width};
+    framework::DDim col_shape = {input_channels / groups, filter_height,
+                                 filter_width, output_height, output_width};
     // use col_matrix_shape in the gemm calculation
     framework::DDim col_matrix_shape = {
-        input_channels * filter_height * filter_width,
+        input_channels / groups * filter_height * filter_width,
         output_height * output_width};
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -78,16 +79,26 @@ class GemmConvKernel : public framework::OpKernel {
         const_cast<platform::DeviceContext*>(context.device_context_);
 
     // convolution operator: im2col + gemm
+    int in_step = input_channels / groups;
+    int out_step = output_channels / groups;
     for (int i = 0; i < batch_size; i++) {
-      // im2col
-      Tensor in_slice = input->Slice(i, i + 1).Resize(input_shape);
-      im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
-             device_context);
-
-      // gemm
-      Tensor out_slice = output->Slice(i, i + 1).Resize(output_matrix_shape);
-      math::matmul<Place, T>(filter, false, col_matrix, false, T(1.0),
-                             &out_slice, T(0.0), device_context);
+      Tensor in_slice_batch = input->Slice(i, i + 1).Resize(input_shape);
+      Tensor out_slice_batch =
+          output->Slice(i, i + 1).Resize(output_matrix_shape);
+      for (int g = 0; g < groups; g++) {
+        // im2col
+        Tensor in_slice =
+            in_slice_batch.Slice(g * in_step, (g + 1) * in_step);
+        im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
+               device_context);
+
+        // gemm
+        Tensor out_slice =
+            out_slice_batch.Slice(g * out_step, (g + 1) * out_step);
+        Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+        math::matmul<Place, T>(filter_slice, false, col_matrix, false, T(1.0),
+                               &out_slice, T(0.0), device_context);
+      }
     }
   }
 };
@@ -114,6 +125,7 @@ class GemmConvGradKernel : public framework::OpKernel {
 
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    // int groups = context.Attr<int>("groups");
 
     int batch_size = input->dims()[0];
     int input_channels = input->dims()[1];
-- 
GitLab