diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index e903f43ba69ee9e28b3a03e8921a41ffa81a2542..000c2089c176adf8d845a56a1f98528734f47ea1 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -158,7 +158,10 @@ op_library(parallel_do_op DEPS executor)
 # Regist multiple Kernel to pybind
 if (WITH_GPU)
-op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col)
+
+op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
+    vol2col depthwise_conv)
+
 op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
 op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
 op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index d6882b275b22b9a2a2b6ff8cfb53a3462bbdbefe..cef7ddd5fe7e12a374fb9cc79211bd2eb97c6c52 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -318,9 +318,25 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
 namespace ops = paddle::operators;
 REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
             ops::ConvOpGrad);
+
+// depthwise convolution op
+REGISTER_OP(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
+            depthwise_conv2d_grad, ops::ConvOpGrad);
 REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
             ops::ConvOpGrad);
 
+// depthwise conv kernel
+// TODO(xingzhaolong): neon kernel for mobile
+REGISTER_OP_CPU_KERNEL(
+    depthwise_conv2d,
+    ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    depthwise_conv2d_grad,
+    ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
+
 REGISTER_OP_CPU_KERNEL(
     conv2d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/operators/conv_op.cu.cc b/paddle/operators/conv_op.cu.cc
index 4f942444f3eb5584f07399b8d1b4d6a5087496d4..d0bd40ee95dab3b2589742b8a0c3a5de7918b5b9 100644
--- a/paddle/operators/conv_op.cu.cc
+++ b/paddle/operators/conv_op.cu.cc
@@ -16,6 +16,16 @@ limitations under the License. */
 
 namespace ops = paddle::operators;
 
+REGISTER_OP_CUDA_KERNEL(
+    depthwise_conv2d,
+    ops::DepthwiseConvKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::DepthwiseConvKernel<paddle::platform::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    depthwise_conv2d_grad,
+    ops::DepthwiseConvGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::DepthwiseConvGradKernel<paddle::platform::CUDADeviceContext, double>);
+
 REGISTER_OP_CUDA_KERNEL(
     conv2d, ops::GemmConvKernel<paddle::platform::CUDADeviceContext, float>,
     ops::GemmConvKernel<paddle::platform::CUDADeviceContext, double>);
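Since depthwise_conv2d reuses Conv2DOpMaker, its semantics are those of a grouped convolution with groups equal to the input channels and a per-channel filter multiplier. A minimal NumPy sketch of the computation the new op performs (the helper and its names are illustrative, not part of this diff):

```python
import numpy as np

def depthwise_conv2d_ref(x, w, stride=(1, 1), padding=(0, 0)):
    """Naive NCHW depthwise convolution. x: (N, C_in, H, W);
    w: (C_out, 1, kh, kw) with C_out a multiple of C_in."""
    n, c_in, ih, iw = x.shape
    c_out, _, kh, kw = w.shape
    mult = c_out // c_in                # the filter_multiplier in the kernels below
    sh, sw = stride
    ph, pw = padding
    oh = (ih + 2 * ph - kh) // sh + 1
    ow = (iw + 2 * pw - kw) // sw + 1
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)), mode='constant')
    y = np.zeros((n, c_out, oh, ow), dtype=x.dtype)
    for co in range(c_out):
        ci = co // mult                 # matches `c_in = c_out / filter_multiplier`
        for i in range(oh):
            for j in range(ow):
                win = xp[:, ci, i * sh:i * sh + kh, j * sw:j * sw + kw]
                y[:, co, i, j] = (win * w[co, 0]).sum(axis=(1, 2))
    return y
```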
diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h
index 5a8933e7915960f9fcbe92ae73c4f37b3b69ecaf..3c1d0e9c1c4bb964bfaebc3bfed115548bd53f97 100644
--- a/paddle/operators/conv_op.h
+++ b/paddle/operators/conv_op.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/depthwise_conv.h"
 #include "paddle/operators/math/im2col.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/vol2col.h"
@@ -350,5 +351,72 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     }
   }
 };
+
+template <typename DeviceContext, typename T>
+class DepthwiseConvKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    Tensor filter = *context.Input<Tensor>("Filter");
+    Tensor* output = context.Output<Tensor>("Output");
+    output->mutable_data<T>(context.GetPlace());
+
+    PADDLE_ENFORCE_EQ(
+        output->dims()[1] % input->dims()[1], 0,
+        "The output channels must be a multiple of the input channels");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+
+    math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    depthwiseConv(dev_ctx, *input, filter, strides, paddings, output);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class DepthwiseConvGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* input = context.Input<Tensor>("Input");
+    const Tensor* output_grad =
+        context.Input<Tensor>(framework::GradVarName("Output"));
+    Tensor* input_grad =
+        context.Output<Tensor>(framework::GradVarName("Input"));
+    Tensor* filter_grad =
+        context.Output<Tensor>(framework::GradVarName("Filter"));
+    Tensor filter = *context.Input<Tensor>("Filter");
+
+    if (!input_grad && !filter_grad) return;
+
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+
+    math::SetConstant<DeviceContext, T> set_zero;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+
+    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
+        depthwiseConvInputGrad;
+    math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
+        depthwiseConvFilterGrad;
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      set_zero(dev_ctx, input_grad, static_cast<T>(0));
+      depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
+                             paddings, input_grad);
+    }
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
+      depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, paddings,
+                              filter_grad);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 28c5aec1996ad04a6cb551ac68c14b613d16858e..768106fadf355ea6fb148491e232dc0ef1453a75 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -8,6 +8,7 @@ if(WITH_GPU)
     nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
+    nv_library(depthwise_conv SRCS depthwise_conv.cu DEPS device_context)
     nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
     nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
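The DepthwiseConvGradKernel added to conv_op.h above only does bookkeeping before handing off to the math functors. A hedged host-side sketch of that control flow, with NumPy standing in for the device functors (all names illustrative):

```python
import numpy as np

def depthwise_conv_grad(input, filter, output_grad,
                        need_input_grad=True, need_filter_grad=True):
    # mirrors `if (!input_grad && !filter_grad) return;`
    if not (need_input_grad or need_filter_grad):
        return None, None
    input_grad = filter_grad = None
    if need_input_grad:
        # set_zero is required: the CUDA input-grad kernel accumulates with +=
        input_grad = np.zeros_like(input)
        # ... DepthwiseConvInputGradFunctor adds into input_grad here ...
    if need_filter_grad:
        # set_zero is required: the filter-grad kernel accumulates via atomicAdd
        filter_grad = np.zeros_like(filter)
        # ... DepthwiseConvFilterGradFunctor adds into filter_grad here ...
    return input_grad, filter_grad
```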
diff --git a/paddle/operators/math/depthwise_conv.cu b/paddle/operators/math/depthwise_conv.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b212e78208355866516211d276cb8046623babd7
--- /dev/null
+++ b/paddle/operators/math/depthwise_conv.cu
@@ -0,0 +1,311 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/depthwise_conv.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// A CUDA kernel to compute the depthwise convolution forward pass
+// in NCHW format.
+template <typename T>
+__global__ void KernelDepthwiseConv(
+    const int nthreads, const T* const input_data, const T* const filter_data,
+    const int batch_size, const int output_channels, const int output_height,
+    const int output_width, const int input_channels, const int input_height,
+    const int input_width, const int filter_multiplier, const int filter_height,
+    const int filter_width, const int stride_height, const int stride_width,
+    const int padding_height, const int padding_width, T* const output_data) {
+  int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
+
+  if (index < nthreads) {
+    const int batch = index / output_channels / output_height / output_width;
+    const int c_out = (index / output_height / output_width) % output_channels;
+    const int h_out = (index / output_width) % output_height;
+    const int w_out = index % output_width;
+
+    const int c_in = c_out / filter_multiplier;
+    const T* weight = filter_data + c_out * filter_height * filter_width;
+    T value = 0;
+    const int h_in_start = -padding_height + h_out * stride_height;
+    const int w_in_start = -padding_width + w_out * stride_width;
+    const int h_in_end = h_in_start + filter_height;
+    const int w_in_end = w_in_start + filter_width;
+
+    const int in_offset =
+        ((batch * input_channels + c_in) * input_height) * input_width;
+
+    const int h_end = h_in_end < input_height ? h_in_end : input_height;
+    const int w_end = w_in_end < input_width ? w_in_end : input_width;
+    const int h_start = h_in_start > 0 ? h_in_start : 0;
+    const int w_start = w_in_start > 0 ? w_in_start : 0;
+
+    for (int h_in = h_start; h_in < h_end; h_in++) {
+      for (int w_in = w_start; w_in < w_end; w_in++) {
+        const int offset = in_offset + h_in * input_width + w_in;
+        value +=
+            weight[(h_in - h_in_start) * filter_width + (w_in - w_in_start)] *
+            input_data[offset];
+      }
+    }
+    output_data[index] = value;
+  }
+}
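Each thread of KernelDepthwiseConv owns exactly one output element, recovered from the flat thread index. A small self-checking Python sketch of that decomposition, confirming it round-trips with the NCHW addressing the offsets assume (dimensions are arbitrary illustrative values):

```python
# Verify the kernel's flat-index decomposition against NCHW recomposition.
N, C_OUT, H_OUT, W_OUT = 2, 6, 4, 5

def decompose(index):
    batch = index // (C_OUT * H_OUT * W_OUT)
    c_out = (index // (H_OUT * W_OUT)) % C_OUT
    h_out = (index // W_OUT) % H_OUT
    w_out = index % W_OUT
    return batch, c_out, h_out, w_out

for index in range(N * C_OUT * H_OUT * W_OUT):
    b, c, h, w = decompose(index)
    # recompose exactly as output_data[index] is addressed in NCHW
    assert ((b * C_OUT + c) * H_OUT + h) * W_OUT + w == index
```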
+
+// A CUDA kernel to compute the depthwise convolution backprop w.r.t. the
+// input.
+template <typename T>
+__global__ void KernelDepthwiseConvInputGrad(
+    const int nthreads, const T* const output_grad_data,
+    const T* const filter_data, const int batch_size, const int output_channels,
+    const int output_height, const int output_width, const int input_channels,
+    const int input_height, const int input_width, const int filter_multiplier,
+    const int filter_height, const int filter_width, const int stride_height,
+    const int stride_width, const int padding_height, const int padding_width,
+    T* const input_grad_data) {
+  int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    const int batch = index / input_channels / input_height / input_width;
+    const int c_in = (index / input_height / input_width) % input_channels;
+    const int h_in = (index / input_width) % input_height;
+    const int w_in = index % input_width;
+
+    const int c_out_start = c_in * filter_multiplier;
+
+    int h_out_start =
+        (h_in - filter_height + padding_height + stride_height) / stride_height;
+    h_out_start = 0 > h_out_start ? 0 : h_out_start;
+
+    int h_out_end = (h_in + padding_height) / stride_height;
+    h_out_end = output_height - 1 < h_out_end ? output_height - 1 : h_out_end;
+
+    int w_out_start =
+        (w_in - filter_width + padding_width + stride_width) / stride_width;
+    w_out_start = 0 > w_out_start ? 0 : w_out_start;
+
+    int w_out_end = (w_in + padding_width) / stride_width;
+    w_out_end = output_width - 1 < w_out_end ? output_width - 1 : w_out_end;
+
+    T value = 0;
+
+    for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
+         c_out++) {
+      for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) {
+        const int filter_h = h_in + padding_height - h_out * stride_height;
+        for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) {
+          const int filter_w = w_in + padding_width - w_out * stride_width;
+          const int filter_offset = c_out * filter_height * filter_width +
+                                    filter_h * filter_width + filter_w;
+          const int output_grad_offset =
+              ((batch * output_channels + c_out) * output_height + h_out) *
+                  output_width +
+              w_out;
+          value +=
+              output_grad_data[output_grad_offset] * filter_data[filter_offset];
+        }
+      }
+    }
+    input_grad_data[index] += value;
+  }
+}
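The least obvious part of the input-grad kernel is the closed-form window bounds: output row h_out touches input row h_in exactly when 0 <= h_in + pad - h_out * stride <= filter_height - 1, and the integer expressions above are the ceil/floor of that inequality. A small self-checking sketch of the same arithmetic against a brute-force scan (parameter names are illustrative):

```python
def out_range(h_in, kh, stride, pad, out_size):
    # same integer arithmetic as the kernel: ceil((h_in + pad - kh + 1) / stride)
    start = max(0, (h_in - kh + pad + stride) // stride)
    end = min(out_size - 1, (h_in + pad) // stride)
    return start, end

def out_range_bruteforce(h_in, kh, stride, pad, out_size):
    hits = [h for h in range(out_size) if 0 <= h_in + pad - h * stride <= kh - 1]
    return (hits[0], hits[-1]) if hits else (0, -1)

# e.g. a 7-pixel input, 3x3 filter, stride 2, pad 1 -> 4 output rows
for h_in in range(7):
    assert out_range(h_in, 3, 2, 1, 4) == out_range_bruteforce(h_in, 3, 2, 1, 4)
```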
+
+// A CUDA kernel to compute the depthwise convolution backprop w.r.t. the
+// filter.
+template <typename T>
+__global__ void KernelDepthwiseConvFilterGrad(
+    const int nthreads, const T* const output_grad_data,
+    const T* const input_data, const int num, const int output_channels,
+    const int output_height, const int output_width, const int input_channels,
+    const int input_height, const int input_width, const int filter_multiplier,
+    const int filter_height, const int filter_width, const int stride_height,
+    const int stride_width, const int padding_height, const int padding_width,
+    T* const filter_grad_data) {
+  int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    const int w_out = index % output_width;
+    const int h_out = (index / output_width) % output_height;
+    const int c_out = (index / output_width / output_height) % output_channels;
+    const int batch = (index / output_width / output_height / output_channels);
+    const int c_in = c_out / filter_multiplier;
+    const int h_in_start = -padding_height + h_out * stride_height;
+    const int w_in_start = -padding_width + w_out * stride_width;
+    const int h_in_end =
+        -padding_height + h_out * stride_height + filter_height;
+    const int w_in_end = -padding_width + w_out * stride_width + filter_width;
+    const int in_offset =
+        (batch * input_channels + c_in) * input_height * input_width;
+
+    T* addr_offset = filter_grad_data + c_out * filter_height * filter_width;
+    const int h_end = h_in_end < input_height ? h_in_end : input_height;
+    const int w_end = w_in_end < input_width ? w_in_end : input_width;
+    const int h_start = h_in_start > 0 ? h_in_start : 0;
+    const int w_start = w_in_start > 0 ? w_in_start : 0;
+
+    for (int h_in = h_start; h_in < h_end; h_in++) {
+      for (int w_in = w_start; w_in < w_end; w_in++) {
+        const int offset = in_offset + h_in * input_width + w_in;
+        const T diff_temp = output_grad_data[index] * input_data[offset];
+        T* addr = addr_offset + (h_in - h_in_start) * filter_width +
+                  (w_in - w_in_start);
+        paddle::platform::CudaAtomicAdd(addr, diff_temp);
+      }
+    }
+  }
+}
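The filter-grad kernel parallelizes over output positions, so every thread for a given c_out writes into the same filter_height x filter_width slice; that write conflict is why it uses CudaAtomicAdd rather than a plain store. A serial NumPy sketch of the same accumulation (the helper is illustrative):

```python
import numpy as np

def depthwise_filter_grad_ref(x, dy, kh, kw, stride=(1, 1), padding=(0, 0)):
    """Reference filter gradient for NCHW depthwise conv; dw has one
    (kh, kw) slice per output channel."""
    n, c_in, ih, iw = x.shape
    _, c_out, oh, ow = dy.shape
    mult = c_out // c_in
    sh, sw = stride
    ph, pw = padding
    dw = np.zeros((c_out, kh, kw), dtype=dy.dtype)
    for b in range(n):
        for co in range(c_out):
            ci = co // mult
            for i in range(oh):
                for j in range(ow):
                    # one CUDA thread handles this (b, co, i, j); all threads
                    # with the same co accumulate into dw[co], hence atomicAdd
                    for u in range(kh):
                        for v in range(kw):
                            hi = i * sh - ph + u
                            wi = j * sw - pw + v
                            if 0 <= hi < ih and 0 <= wi < iw:
                                dw[co, u, v] += dy[b, co, i, j] * x[b, ci, hi, wi]
    return dw
```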
+
+/*
+ * All tensors are in NCHW format.
+ * Ksize, strides, paddings are two elements. These two elements represent
+ * height and width, respectively.
+ */
+template <typename T>
+class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& filter,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
+                  framework::Tensor* output) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output->dims()[1];
+    const int output_height = output->dims()[2];
+    const int output_width = output->dims()[3];
+    const int ksize_height = filter.dims()[2];
+    const int ksize_width = filter.dims()[3];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const T* input_data = input.data<T>();
+    const T* filter_data = filter.data<T>();
+    T* output_data = output->mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_height * output_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelDepthwiseConv<T><<<grid, threads, 0, context.stream()>>>(
+        nthreads, input_data, filter_data, batch_size, output_channels,
+        output_height, output_width, input_channels, input_height, input_width,
+        output_channels / input_channels, ksize_height, ksize_width,
+        stride_height, stride_width, padding_height, padding_width,
+        output_data);
+  }
+};
+
+template <typename T>
+class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& filter,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
+                  framework::Tensor* input_grad) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output_grad.dims()[1];
+    const int output_height = output_grad.dims()[2];
+    const int output_width = output_grad.dims()[3];
+    const int ksize_height = filter.dims()[2];
+    const int ksize_width = filter.dims()[3];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const T* filter_data = filter.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * input_channels * input_height * input_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelDepthwiseConvInputGrad<T><<<grid, threads, 0, context.stream()>>>(
+        nthreads, output_grad_data, filter_data, batch_size, output_channels,
+        output_height, output_width, input_channels, input_height, input_width,
+        output_channels / input_channels, ksize_height, ksize_width,
+        stride_height, stride_width, padding_height, padding_width,
+        input_grad_data);
+  }
+};
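All three functors share one launch recipe: a flat one-thread-per-element decomposition with 1024 threads per block and a 1-D grid of ceil(nthreads / 1024) blocks. With gridDim.y == 1, the kernels' index expression collapses to the familiar blockIdx.x * blockDim.x + threadIdx.x. The ceil division in Python for clarity:

```python
# Same ceil-division launch math as `(nthreads + 1024 - 1) / 1024` in the C++.
def launch_config(nthreads, block=1024):
    blocks = (nthreads + block - 1) // block
    return blocks, block

# e.g. a 2x64x32x32 output tensor needs 128 blocks of 1024 threads
assert launch_config(2 * 64 * 32 * 32) == (128, 1024)
```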
+
+template <typename T>
+class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
+                  framework::Tensor* filter_grad) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output_grad.dims()[1];
+    const int output_height = output_grad.dims()[2];
+    const int output_width = output_grad.dims()[3];
+    const int ksize_height = filter_grad->dims()[2];
+    const int ksize_width = filter_grad->dims()[3];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const T* input_data = input.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_height * output_width;
+
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelDepthwiseConvFilterGrad<T><<<grid, threads, 0, context.stream()>>>(
+        nthreads, output_grad_data, input_data, batch_size, output_channels,
+        output_height, output_width, input_channels, input_height, input_width,
+        output_channels / input_channels, ksize_height, ksize_width,
+        stride_height, stride_width, padding_height, padding_width,
+        filter_grad_data);
+  }
+};
+
+template class DepthwiseConvFunctor<platform::CUDADeviceContext, float>;
+template class DepthwiseConvFunctor<platform::CUDADeviceContext, double>;
+
+template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
+                                             float>;
+template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
+                                             double>;
+
+template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
+                                              float>;
+template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
+                                              double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
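Because the forward output is linear in the filter, a central finite difference checks the filter-grad accumulation exactly (this is also what the op tests below do through fluid's gradient checker). A hedged NumPy sketch, stride 1, no padding, multiplier 1; all names are illustrative and not part of the PR:

```python
import numpy as np

def dwconv(x, w):  # x: (N, C, H, W), w: (C, kh, kw)
    n, c, ih, iw = x.shape
    _, kh, kw = w.shape
    oh, ow = ih - kh + 1, iw - kw + 1
    y = np.zeros((n, c, oh, ow))
    for i in range(oh):
        for j in range(ow):
            y[:, :, i, j] = (x[:, :, i:i + kh, j:j + kw] * w).sum(axis=(2, 3))
    return y

rng = np.random.RandomState(0)
x, w = rng.randn(2, 3, 5, 5), rng.randn(3, 3, 3)
dy = rng.randn(*dwconv(x, w).shape)

# analytic filter grad: the same per-window accumulation the atomicAdd kernel does
dw = np.zeros_like(w)
for i in range(dy.shape[2]):
    for j in range(dy.shape[3]):
        dw += np.einsum('nc,nchw->chw', dy[:, :, i, j], x[:, :, i:i + 3, j:j + 3])

# finite-difference check of one weight against the analytic gradient
eps = 1e-6
u = np.zeros_like(w)
u[1, 2, 0] = 1.0
fd = ((dwconv(x, w + eps * u) - dwconv(x, w - eps * u)) * dy).sum() / (2 * eps)
assert abs(fd - dw[1, 2, 0]) < 1e-4
```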
diff --git a/paddle/operators/math/depthwise_conv.h b/paddle/operators/math/depthwise_conv.h
new file mode 100644
index 0000000000000000000000000000000000000000..4708920bb42db90d84fda0c6a1039991cb79e80d
--- /dev/null
+++ b/paddle/operators/math/depthwise_conv.h
@@ -0,0 +1,60 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * \brief Compute the depthwise convolution, including the forward process
+ * and the backpropagation process.
+ */
+template <typename DeviceContext, typename T>
+class DepthwiseConvFunctor {
+ public:
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  const framework::Tensor& filter,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings, framework::Tensor* output);
+};
+
+template <typename DeviceContext, typename T>
+class DepthwiseConvInputGradFunctor {
+ public:
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  const framework::Tensor& filter,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
+                  framework::Tensor* input_grad);
+};
+
+template <typename DeviceContext, typename T>
+class DepthwiseConvFilterGradFunctor {
+ public:
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  const framework::Tensor& output_grad,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
+                  framework::Tensor* filter_grad);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index cb8a4815db589faa7d6665ac1b3e27e0252b107f..a79479f469a0c489edf2676bc5d07066bb480664 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -1231,10 +1231,17 @@ def conv2d(input,
     """
     if stride is None:
         stride = [1, 1]
-    helper = LayerHelper('conv2d', **locals())
-    dtype = helper.input_dtype()
     num_channels = input.shape[1]
+
+    l_type = 'conv2d'
+    if (num_channels == groups and num_filters % num_channels == 0 and
+            not use_cudnn):
+        l_type = 'depthwise_conv2d'
+
+    helper = LayerHelper(l_type, **locals())
+    dtype = helper.input_dtype()
+
     if groups is None:
         num_filter_channels = num_channels
     else:
@@ -1267,7 +1274,7 @@ def conv2d(input,
 
     pre_bias = helper.create_tmp_variable(dtype)
     helper.append_op(
-        type='conv2d',
+        type=l_type,
         inputs={
             'Input': input,
             'Filter': filter_param,
diff --git a/python/paddle/v2/fluid/tests/test_conv2d_op.py b/python/paddle/v2/fluid/tests/test_conv2d_op.py
index 24de74d730eedbccb4837598bd6d2eb92da59e0d..7512ea333e37d5f4f0102531d8d13f8c2a744b8d 100644
--- a/python/paddle/v2/fluid/tests/test_conv2d_op.py
+++ b/python/paddle/v2/fluid/tests/test_conv2d_op.py
@@ -241,6 +241,30 @@ class TestCUDNNWith1x1(TestWith1x1):
         self.op_type = "conv2d"
 
 
+class TestDepthwiseConv(TestConv2dOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] / self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
+class TestDepthwiseConv2(TestConv2dOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.groups = 3
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] / self.groups
+        self.filter_size = [6, f_c, 3, 3]
+        self.op_type = "depthwise_conv2d"
+
+
 # cudnn v5 does not support dilation conv.
 # class TestCUDNNWithDilation(TestWithDilation):
 #     def init_op_type(self):
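With the nn.py change, fluid.layers.conv2d transparently lowers to the new op whenever the convolution is depthwise-shaped and cuDNN is off; callers need no code change. A hedged usage sketch against the fluid API of this era (shapes are illustrative):

```python
import paddle.v2.fluid as fluid

data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
# num_channels == groups == 3 and num_filters == 6 is a multiple of 3, so
# with use_cudnn=False the layer appends a depthwise_conv2d op instead of
# conv2d; with use_cudnn=True it keeps the cuDNN-backed conv2d path.
conv = fluid.layers.conv2d(
    input=data, num_filters=6, filter_size=3, groups=3, use_cudnn=False)
```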