From b756063ce7e71528d57c67caa94871bd924729d9 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Wed, 27 Jun 2018 15:03:29 +0800
Subject: [PATCH] Speed depthwise transposed conv2d. (#11740)

* Speed depthwise transposed conv2d.
---
 paddle/fluid/operators/conv_transpose_op.cc  | 18 +++++
 .../fluid/operators/conv_transpose_op.cu.cc  | 45 ++++++------
 paddle/fluid/operators/conv_transpose_op.h   | 70 +++++++++++++++++++
 python/paddle/fluid/layers/nn.py             | 13 +++-
 .../unittests/test_conv2d_transpose_op.py    | 13 ++++
 5 files changed, 135 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc
index 2e9e957eb..eeb98ee44 100644
--- a/paddle/fluid/operators/conv_transpose_op.cc
+++ b/paddle/fluid/operators/conv_transpose_op.cc
@@ -302,6 +302,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
 
 namespace ops = paddle::operators;
 
+// conv2d_transpose
 REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
@@ -317,6 +318,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                      double>);
 
+// conv3d_transpose
 REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
                   ops::Conv3DTransposeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
@@ -331,3 +333,19 @@ REGISTER_OP_CPU_KERNEL(
     ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                      float>,
     ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                      double>);
+
+// depthwise conv2d_transpose
+REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
+                  ops::Conv2DTransposeOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    depthwise_conv2d_transpose,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    depthwise_conv2d_transpose_grad,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
+                                     float>,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
+                                     double>);

diff --git a/paddle/fluid/operators/conv_transpose_op.cu.cc b/paddle/fluid/operators/conv_transpose_op.cu.cc
index 640fa7d14..a6d5665df 100644
--- a/paddle/fluid/operators/conv_transpose_op.cu.cc
+++ b/paddle/fluid/operators/conv_transpose_op.cu.cc
@@ -15,25 +15,28 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_transpose_op.h"
 
 namespace ops = paddle::operators;
+using CUDA = paddle::platform::CUDADeviceContext;
 
-REGISTER_OP_CUDA_KERNEL(
-    conv2d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
-REGISTER_OP_CUDA_KERNEL(
-    conv2d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
-                                     float>,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
-                                     double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    conv3d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::GemmConvTransposeKernel<paddle::platform::CUDADeviceContext, double>);
-REGISTER_OP_CUDA_KERNEL(
-    conv3d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
-                                     float>,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CUDADeviceContext,
-                                     double>);
+// conv2d
+REGISTER_OP_CUDA_KERNEL(conv2d_transpose,
+                        ops::GemmConvTransposeKernel<CUDA, float>,
+                        ops::GemmConvTransposeKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad,
+                        ops::GemmConvTransposeGradKernel<CUDA, float>,
+                        ops::GemmConvTransposeGradKernel<CUDA, double>);
+
+// conv3d
+REGISTER_OP_CUDA_KERNEL(conv3d_transpose,
+                        ops::GemmConvTransposeKernel<CUDA, float>,
+                        ops::GemmConvTransposeKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(conv3d_transpose_grad,
+                        ops::GemmConvTransposeGradKernel<CUDA, float>,
+                        ops::GemmConvTransposeGradKernel<CUDA, double>);
+
+// depthwise conv2d
+REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose,
+                        ops::DepthwiseConvTransposeKernel<CUDA, float>,
+                        ops::DepthwiseConvTransposeKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose_grad,
+                        ops::DepthwiseConvTransposeGradKernel<CUDA, float>,
+                        ops::DepthwiseConvTransposeGradKernel<CUDA, double>);
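Note on the registrations above: the CPU build simply reuses the generic GemmConvTransposeKernel for the new op, so the dedicated fast path exists only in the CUDA kernels. As a reminder of the op's semantics, here is a minimal NumPy sketch of a depthwise transposed convolution (illustration only, not Paddle code; the function name and the scalar stride/pad arguments are invented for this sketch):

    import numpy as np

    def depthwise_conv2d_transpose_ref(x, w, stride=1, pad=0):
        # x: (N, C, H, W); w: (C, 1, KH, KW) -- one filter per input channel,
        # i.e. groups == C, the condition the kernel below enforces.
        n, c, h, wd = x.shape
        kh, kw = w.shape[2], w.shape[3]
        out_h = (h - 1) * stride + kh  # padding is cropped at the end
        out_w = (wd - 1) * stride + kw
        out = np.zeros((n, c, out_h, out_w), dtype=x.dtype)
        for i in range(h):
            for j in range(wd):
                # each input pixel scatters through its own channel's filter
                out[:, :, i * stride:i * stride + kh,
                    j * stride:j * stride + kw] += x[:, :, i:i + 1, j:j + 1] * w[:, 0]
        return out[:, :, pad:out_h - pad, pad:out_w - pad]

The cropped result has height (H - 1) * stride + KH - 2 * pad, the usual transposed-convolution output size.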
*/ #include "paddle/fluid/operators/conv_transpose_op.h" namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; -REGISTER_OP_CUDA_KERNEL( - conv2d_transpose, - ops::GemmConvTransposeKernel, - ops::GemmConvTransposeKernel); -REGISTER_OP_CUDA_KERNEL( - conv2d_transpose_grad, - ops::GemmConvTransposeGradKernel, - ops::GemmConvTransposeGradKernel); - -REGISTER_OP_CUDA_KERNEL( - conv3d_transpose, - ops::GemmConvTransposeKernel, - ops::GemmConvTransposeKernel); -REGISTER_OP_CUDA_KERNEL( - conv3d_transpose_grad, - ops::GemmConvTransposeGradKernel, - ops::GemmConvTransposeGradKernel); +// conv2d +REGISTER_OP_CUDA_KERNEL(conv2d_transpose, + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); +REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad, + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); + +// conv3d +REGISTER_OP_CUDA_KERNEL(conv3d_transpose, + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); +REGISTER_OP_CUDA_KERNEL(conv3d_transpose_grad, + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); + +// depthwise conv2d +REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose, + ops::DepthwiseConvTransposeKernel, + ops::DepthwiseConvTransposeKernel); +REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose_grad, + ops::DepthwiseConvTransposeGradKernel, + ops::DepthwiseConvTransposeGradKernel); diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index 1dcfc651f..0d9c6a62f 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/vol2col.h" @@ -316,5 +317,74 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { } } }; + +template +class DepthwiseConvTransposeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + Tensor filter = *context.Input("Filter"); + Tensor* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + int groups = context.Attr("groups"); + PADDLE_ENFORCE_EQ(groups, filter.dims()[0]); + + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + std::vector dilations = context.Attr>("dilations"); + for (auto v : dilations) { + PADDLE_ENFORCE_EQ(v, 1); + } + + output->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + math::SetConstant set_zero; + set_zero(dev_ctx, output, static_cast(0)); + + math::DepthwiseConvInputGradFunctor + depthwiseConvInputGrad; + depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings, + output); + } +}; + +template +class DepthwiseConvTransposeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + Tensor filter = *context.Input("Filter"); + + if 
(!input_grad && !filter_grad) return; + + auto& dev_ctx = context.template device_context(); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + if (input_grad) { + math::DepthwiseConvFunctor depthwiseConv; + depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings, + input_grad); + } + + if (filter_grad) { + math::SetConstant set_zero; + filter_grad->mutable_data(context.GetPlace()); + set_zero(dev_ctx, filter_grad, static_cast(0)); + + math::DepthwiseConvFilterGradFunctor + depthwiseConvFilterGrad; + depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings, + filter_grad); + } + } +}; } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f5700ed56..02ea2af32 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2334,10 +2334,17 @@ def conv2d_transpose(input, data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32') conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3) """ - helper = LayerHelper("conv2d_transpose", **locals()) + + input_channel = input.shape[1] + + op_type = 'conv2d_transpose' + if (input_channel == groups and num_filters == input_channel and + not use_cudnn): + op_type = 'depthwise_conv2d_transpose' + + helper = LayerHelper(op_type, **locals()) if not isinstance(input, Variable): raise TypeError("Input of conv2d_transpose must be Variable") - input_channel = input.shape[1] padding = utils.convert_to_list(padding, 2, 'padding') stride = utils.convert_to_list(stride, 2, 'stride') @@ -2371,7 +2378,7 @@ def conv2d_transpose(input, pre_bias = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( - type='conv2d_transpose', + type=op_type, inputs={'Input': [input], 'Filter': [img_filter]}, outputs={'Output': pre_bias}, diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py index ded2f1302..07545e7fe 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py @@ -242,6 +242,19 @@ class TestCUDNNWithGroups(TestWithGroups): self.op_type = "conv2d_transpose" +class TestDepthwiseConvTranspose(TestConv2dTransposeOp): + def init_test_case(self): + self.pad = [1, 1] + self.stride = [2, 2] + self.dilations = [1, 1] + self.input_size = [2, 8, 16, 16] # NCHW + self.groups = 8 + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] / self.groups + self.filter_size = [self.input_size[1], f_c, 4, 4] + self.op_type = "depthwise_conv2d_transpose" + + # Please Don't remove the following code. # Currently, CI use cudnn V5.0 which not support dilation conv. # class TestCUDNNWithDilation(TestWithDilation): -- GitLab