diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index dfa2f784610b0dd60340e0ebc6a066437f3715eb..7f32c734791853a8cd0287a80a7955dbd1bd7571 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -31,13 +31,22 @@ public: ConvolutionTest(const std::string& conv1, const std::string& conv2, TestType type, + bool useGroups = true, std::string algo = "auto") { for (size_t batchSize : {1, 32}) { for (size_t inputSize : {7, 14, 54}) { for (size_t filterSize : {1, 3, 5}) { for (size_t inputChannels : {3, 64}) { - for (size_t outputChannels : {3, 64, 128}) { - if (inputChannels < outputChannels) break; + for (size_t outputChannels : {3, 64}) { + if (inputChannels > outputChannels) break; + size_t groups; + if (!useGroups) { + groups = 1; + } else { + if (outputChannels % inputChannels != 0) continue; + groups = inputChannels; + } + for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { if (padding >= filterSize) break; @@ -62,13 +71,24 @@ public: FuncConfig() .set("paddings", paddings) .set("strides", strides) - .set("groups", (size_t)1) + .set("groups", groups) .set("algo", algo)); TensorShape input{ batchSize, inputChannels, inputSize, inputSize}; - TensorShape filter{ - outputChannels, inputChannels, filterSize, filterSize}; + + TensorShape filter; + if (groups > 1) + filter = TensorShape({groups, + outputChannels / groups, + inputChannels / groups, + filterSize, + filterSize}); + else + filter = TensorShape({outputChannels, + inputChannels, + filterSize, + filterSize}); TensorShape output{ batchSize, outputChannels, outputSize, outputSize}; @@ -85,7 +105,8 @@ public: } else if (type == kBackwardFilterTest) { test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); - test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter), + ADD_TO); test.run(); } } @@ -106,6 +127,7 @@ public: ConvolutionTest2(const std::string& conv1, const std::string& conv2, TestType type, + bool useGroups = true, std::string algo = "auto") { for (size_t batchSize : {16}) { for (size_t inputHeight : {7, 31}) { @@ -113,7 +135,15 @@ public: for (size_t filterHeight : {1, 5}) { for (size_t filterWidth : {3, 7}) { for (size_t inputChannels : {7}) { - for (size_t outputChannels : {32}) { + for (size_t outputChannels : {7}) { + size_t groups; + if (!useGroups) { + groups = 1; + } else { + if (outputChannels % inputChannels != 0) continue; + groups = inputChannels; + } + size_t stride = 1; size_t padding = 0; size_t outputHeight = @@ -141,13 +171,24 @@ public: FuncConfig() .set("paddings", paddings) .set("strides", strides) - .set("groups", (size_t)1) + .set("groups", groups) .set("algo", algo)); TensorShape input{ batchSize, inputChannels, inputHeight, inputWidth}; - TensorShape filter{ - outputChannels, inputChannels, filterHeight, filterWidth}; + + TensorShape filter; + if (groups > 1) + filter = TensorShape({groups, + outputChannels / groups, + inputChannels / groups, + filterHeight, + filterWidth}); + else + filter = TensorShape({outputChannels, + inputChannels, + filterHeight, + filterWidth}); TensorShape output{ batchSize, outputChannels, outputHeight, outputWidth}; @@ -164,7 +205,8 @@ public: } else if (type == kBackwardFilterTest) { test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); - test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter), + ADD_TO); test.run(); } } @@ -177,34 +219,88 @@ public: } }; +// ======Start Convolution TEST====== + TEST(Forward, GEMM) { ConvolutionTest test( - "NaiveConv-CPU", "GemmConv-CPU", kForwardTest); + "NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false); ConvolutionTest2 test2( - "NaiveConv-CPU", "GemmConv-CPU", kForwardTest); + "NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false); } #ifndef PADDLE_ONLY_CPU TEST(Forward, GEMM2) { ConvolutionTest test( - "GemmConv-CPU", "GemmConv-GPU", kForwardTest); + "GemmConv-CPU", "GemmConv-GPU", kForwardTest, false); ConvolutionTest2 test2( - "GemmConv-CPU", "GemmConv-GPU", kForwardTest); + "GemmConv-CPU", "GemmConv-GPU", kForwardTest, false); } TEST(BackwardInput, GEMM) { ConvolutionTest test( - "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest); + "GemmConvGradInput-CPU", + "GemmConvGradInput-GPU", + kBackwardInputTest, + false); ConvolutionTest2 test2( - "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest); + "GemmConvGradInput-CPU", + "GemmConvGradInput-GPU", + kBackwardInputTest, + false); } TEST(BackwardFilter, GEMM) { ConvolutionTest test( - "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest); + "GemmConvGradFilter-CPU", + "GemmConvGradFilter-GPU", + kBackwardFilterTest, + false); ConvolutionTest2 test2( - "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest); + "GemmConvGradFilter-CPU", + "GemmConvGradFilter-GPU", + kBackwardFilterTest, + false); } #endif +// ======End Convolution TEST====== + +// ======Start DepthwiseConvolution TEST====== + +// TODO(zhaolong) The depthwise convolution cpu test will be added when the cpu +// version of depthwiseConv is implemented. + +#ifndef PADDLE_ONLY_CPU + +TEST(DepthwiseConvForward, GEMM2) { + ConvolutionTest test( + "GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest); + ConvolutionTest2 test2( + "GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest); +} + +TEST(DepthwiseConvBackwardInput, GEMM) { + ConvolutionTest test( + "GemmConvGradInput-CPU", + "DepthwiseConvGradInput-GPU", + kBackwardInputTest); + ConvolutionTest2 test2( + "GemmConvGradInput-CPU", + "DepthwiseConvGradInput-GPU", + kBackwardInputTest); +} + +TEST(DepthwiseConvBackwardFilter, GEMM) { + ConvolutionTest test( + "GemmConvGradFilter-CPU", + "DepthwiseConvGradFilter-GPU", + kBackwardFilterTest); + ConvolutionTest2 test2( + "GemmConvGradFilter-CPU", + "DepthwiseConvGradFilter-GPU", + kBackwardFilterTest); +} + +#endif +// ======End DepthwiseConvolution TEST====== } // namespace paddle diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..490e8d546cbd460217abe95f6291b13fa207faa9 --- /dev/null +++ b/paddle/function/DepthwiseConvOp.cpp @@ -0,0 +1,306 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DepthwiseConvOp.h" +#include "ConvOp.h" +#include "GemmFunctor.h" + +namespace paddle { + +template +class DepthwiseConvFunctor { +public: + void operator()(const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData) { + // TODO(zhaolong) : cpu implementation of depthwise convolution + } +}; + +template +class DepthwiseConvGradInputFunctor { +public: + void operator()(const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad) {} + // TODO(zhaolong) : cpu implementation of depthwise convolution +}; + +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* filterGrad) {} + // TODO(zhaolong) : cpu implementation of depthwise convolution +}; + +/* + * \brief Forward calculation of depthwise convolution. + */ +template +class DepthwiseConvFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + + DepthwiseConvFunctor depthwiseConv; + depthwiseConv(inputData, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputData); + } +}; + +/* + * \brief Backward input calculation of depthwise convolution. + */ +template +class DepthwiseConvGradInputFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + check(inputs, outputs); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + real* outputGrad = inputs[0].data(); + real* filterData = inputs[1].data(); + real* inputGrad = outputs[0].data(); + + DepthwiseConvGradInputFunctor depthwiseConvGradInput; + depthwiseConvGradInput(outputGrad, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + inputGrad); + } +}; + +/* + * \brief Backward filter calculation of depthwise convolution. + */ +template +class DepthwiseConvGradFilterFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + check(inputs, outputs); + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + real* outputGrad = inputs[0].data(); + real* inputData = inputs[1].data(); + real* filterGrad = outputs[0].data(); + + int size = outputChannels * filterHeight * filterWidth * outputHeight * + outputWidth; + resizeBuffer(size); + real* colData = reinterpret_cast(memory_->getBuf()); + + DepthwiseConvGradFilterFunctor depthwiseConvGradFilter; + + depthwiseConvGradFilter(outputGrad, + inputData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + colData, + filterGrad); + } +}; + +REGISTER_TYPED_FUNC(DepthwiseConv, CPU, DepthwiseConvFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradInput, + CPU, + DepthwiseConvGradInputFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, + CPU, + DepthwiseConvGradFilterFunction); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradInput, + GPU, + DepthwiseConvGradInputFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, + GPU, + DepthwiseConvGradFilterFunction); +#endif + +} // namespace paddle diff --git a/paddle/function/DepthwiseConvOp.h b/paddle/function/DepthwiseConvOp.h new file mode 100644 index 0000000000000000000000000000000000000000..1bf70e52f34626405b49571e023ac60926713eef --- /dev/null +++ b/paddle/function/DepthwiseConvOp.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "TensorType.h" + +namespace paddle { + +/** + *\brief Depthwise convolution forward. The outputData + * of depthwise convolution is same with ExpandConvLayer + * when groups equals inputChannels in ExpandConvLayer. + * + * \param[in] inputData input data. + * \param[in] filterData the Paramters of the depthwise conv layer.. + * \param[in] batchSize batch size of input data. + * \param[in] outputChannels channels of outputData. + * \param[in] outputHeight height of outputData. + * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of inputData. + * \param[in] inputHeight height of inputData. + * \param[in] inputWidth width of inputData.. + * \param[in] filterMultiplier equals to outputChannels/groups_. + * \param[in] filterHeight height of filter. + * \param[in] filterWidth widht of filter. + * \param[in] strideH stride size in height direction. + * \param[in] strideW stride size in width direction. + * \param[in] paddingH padding size in height direction. + * \param[in] paddingW padding size in width direction. + * \param[out] outputData outputData. + * + */ +template +class DepthwiseConvFunctor { +public: + void operator()(const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData); +}; + +/** + *\brief Functor tot compute the depthwise convolution backprop w.r.t input. + * + * + * \param[in] outputGradData the grad data of output. + * \param[in] filterData the Paramters of the depthwise conv layer.. + * \param[in] batchSize batch size of input data. + * \param[in] outputChannels channels of outputData. + * \param[in] outputHeight height of outputData. + * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of input data. + * \param[in] inputHeight height of inputData. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. + * \param[in] filterHeight height of filter. + * \param[in] filterWidth widht of filter. + * \param[in] strideH stride size in height direction. + * \param[in] strideW stride size in width direction. + * \param[in] paddingH padding size in height direction. + * \param[in] paddingW padding size in width direction. + * \param[out] inputGrad the grad data of input. + * + */ +template +class DepthwiseConvGradInputFunctor { +public: + void operator()(const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad); +}; + +/** + *\brief Functor tot compute the depthwise convolution backprop w.r.t filter. + * + * \param[in] outputGradData the grad data of output. + * \param[in] inputData inputData. + * \param[in] batchSize batch size of input data. + * \param[in] outputChannels channels of outputData. + * \param[in] outputHeight height of outputData. + * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of input data. + * \param[in] inputHeight height of inputData. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. + * \param[in] filterHeight height of filter. + * \param[in] filterWidth widht of filter. + * \param[in] strideH stride size in height direction. + * \param[in] strideW stride size in width direction. + * \param[in] paddingH padding size in height direction. + * \param[in] paddingW padding size in width direction. + * \param[in] colData Auxiliary data when calculating filterGrad. + * \param[in] multiplierData Auxiliary data when calculating filterGrad. + * \param[out] filterGrad the grad data of filter. + * + */ +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* filterGrad); +}; + +} // namespace paddle diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..ede0d27aa82e7d71ff5bc33df110fec260e06463 --- /dev/null +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -0,0 +1,342 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DepthwiseConvOp.h" +#include "GemmFunctor.h" +#include "paddle/math/BaseMatrix.h" + +namespace paddle { + +// CUDA kernel to compute the depthwise convolution forward pass +template +__global__ +void ConvolutionDepthwiseForward(const int nthreads, + const T* const inputData, const T* const filterData, + const int batchSize, const int outputChannels, const int outputHeight, + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const outputData) { + + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + if (index < nthreads) { + const int batch = index / outputChannels / outputHeight / outputWidth; + const int c_out = (index / outputHeight / outputWidth) % outputChannels; + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; + + const int c_in = c_out / filterMultiplier; + const T* weight = filterData + c_out * filterHeight * filterWidth; + T value = 0; + const int h_in_start = -paddingH + h_out * strideH; + const int w_in_start = -paddingW + w_out * strideW; + const int h_in_end = -paddingH + h_out * strideH + filterHeight - 1; + const int w_in_end = -paddingW + w_out * strideW + filterWidth - 1; + if ((h_in_start >= 0) && (h_in_end < inputHeight) + && (w_in_start >= 0) && (w_in_end < inputWidth)) { + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + const int offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; + value += (*weight) * inputData[offset]; + ++weight; + } + } + } else { + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + if ((h_in >= 0) && (h_in < inputHeight) + && (w_in >= 0) && (w_in < inputWidth)) { + const int offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; + value += (*weight) * inputData[offset]; + } + ++weight; + } + } + } + outputData[index] = value; + } +} + +// CUDA kernel to compute the depthwise convolution backprop w.r.t input. +template +__global__ +void ConvolutionDepthwiseInputBackward(const int nthreads, + const T* const top_diff, const T* const weight_data, + const int num, const int outputChannels, const int outputHeight, + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const bottom_diff) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int batch = index / inputChannels / inputHeight / inputWidth; + const int c_in = (index / inputHeight / inputWidth) % inputChannels; + const int h_in = (index / inputWidth) % inputHeight; + const int w_in = index % inputWidth; + + const int c_out_start = c_in * filterMultiplier; + + int h_out_start = (h_in - filterHeight + paddingH + strideH)/strideH; + h_out_start = 0 > h_out_start ? 0 : h_out_start; + int h_out_end = (h_in + paddingH)/strideH; + h_out_end = outputHeight - 1 < h_out_end? outputHeight - 1 : h_out_end; + int w_out_start = (w_in - filterWidth + paddingW + strideW)/strideW; + w_out_start = 0 > w_out_start ? 0 : w_out_start; + int w_out_end = (w_in + paddingW)/strideW; + w_out_end = outputWidth - 1 < w_out_end? outputWidth - 1 : w_out_end; + + T value = 0; + + for (int c_out = c_out_start; + c_out < c_out_start + filterMultiplier; c_out ++) { + for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) { + const int filter_h = h_in + paddingH - h_out * strideH; + for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) { + const int filter_w = w_in + paddingW - w_out * strideW; + const int filter_offset = c_out * filterHeight * filterWidth + + filter_h * filterWidth + filter_w; + const int top_diff_offset = ((batch * outputChannels + c_out) * + outputHeight + h_out)* outputWidth + w_out; + value += top_diff[top_diff_offset] * weight_data[filter_offset]; + } + } + } + bottom_diff[index] += value; + } +} + +// CUDA kernel to compute the depthwise convolution backprop w.r.t filter. +template +__global__ +void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, + const T* const top_diff, const T* const inputData, + const int num, const int outputChannels, const int outputHeight, + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const buffer_data) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; + const int kh = (index / filterWidth / outputHeight / outputWidth) + % filterHeight; + const int kw = (index / outputHeight / outputWidth) % filterWidth; + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + if ((h_in >= 0) && (h_in < inputHeight) + && (w_in >= 0) && (w_in < inputWidth)) { + const int c_out = index / + (filterHeight * filterWidth * outputHeight * outputWidth); + const int c_in = c_out / filterMultiplier; + const int batch = num_i; + const int top_offset = ((batch * outputChannels + c_out) * + outputHeight + h_out) * outputWidth + w_out; + const int bottom_offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; + buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; + } else { + buffer_data[index] = 0; + } + } +} + +template +class DepthwiseConvFunctor{ +public: + void operator()(const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData){ + int outputSize = batchSize * outputChannels * outputHeight * outputWidth; + + size_t blocks = (outputSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + ConvolutionDepthwiseForward + <<< grid, threads, 0, STREAM_DEFAULT >>>( + outputSize, + inputData, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + outputData); + } +}; + +template +class DepthwiseConvGradInputFunctor{ +public: + void operator()(const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad){ + int inputSize = batchSize * inputChannels * inputHeight * inputWidth; + + size_t blocks = (inputSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + + ConvolutionDepthwiseInputBackward + // NOLINT_NEXT_LINE(whitespace/operators) + <<< grid, threads, 0, STREAM_DEFAULT >>>( + inputSize, + outputGrad, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + inputGrad); + } +}; + +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* filterGrad){ + int colDataSize = outputChannels * filterHeight * filterWidth + * outputHeight * outputWidth; + + size_t blocks = (colDataSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth, + 1, filterGrad, false, true); + + for (int i = 0; i < batchSize; i++) { + ConvolutionDepthwiseFilterBackward + <<< grid, threads, 0, STREAM_DEFAULT >>>( + i, + colDataSize, + outputGrad, + inputData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + colData); + int K = outputHeight * outputWidth; + int M = colDataSize / K; + + BaseMatrix colMatrix(M, K, colData, false, true); + filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); + } + } +}; + +#ifdef PADDLE_TYPE_DOUBLE +template class DepthwiseConvGradInputFunctor; +template class DepthwiseConvFunctor; +template class DepthwiseConvGradFilterFunctor; +#else +template class DepthwiseConvGradInputFunctor; +template class DepthwiseConvFunctor; +template class DepthwiseConvGradFilterFunctor; +#endif + +} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index af79e65a7c09e5a1b55febf1df1e8f5bb61bdcb8..783e02e47cb91e28eb88b079f1e94439d34fa775 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -38,10 +38,25 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, inputShape_.resize(numInputs); filterShape_.resize(numInputs); outputShape_.resize(numInputs); + + std::string convType; + std::string convGradInputType; + std::string convGradFilterType; + for (int i = 0; i < config_.inputs_size(); i++) { std::vector paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector strides = {(size_t)strideY_[i], (size_t)stride_[i]}; + if (useGpu_ && (size_t)groups_[i] == (size_t)channels_[i] && !isDeconv_) { + convType = "DepthwiseConv"; + convGradInputType = "DepthwiseConvGradInput"; + convGradFilterType = "DepthwiseConvGradFilter"; + } else { + convType = "GemmConv"; + convGradInputType = "GemmConvGradInput"; + convGradFilterType = "GemmConvGradFilter"; + } + if (FLAGS_use_nnpack) { CHECK_EQ(isDeconv_, false); createFunction(forward_, @@ -53,21 +68,21 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, .set("algo", std::string("auto"))); } else { createFunction(forward_, - !isDeconv_ ? "GemmConv" : "GemmConvGradInput", + !isDeconv_ ? convType : convGradInputType, FuncConfig() .set("paddings", paddings) .set("strides", strides) .set("groups", (size_t)groups_[i])); createFunction(backward_, - !isDeconv_ ? "GemmConvGradInput" : "GemmConv", + !isDeconv_ ? convGradInputType : convType, FuncConfig() .set("paddings", paddings) .set("strides", strides) .set("groups", (size_t)groups_[i])); createFunction(backward_, - "GemmConvGradFilter", + convGradFilterType, FuncConfig() .set("paddings", paddings) .set("strides", strides) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 9af083468c0f01218117211f9e4931ca0669e96a..0975c3bc9573c6ccb8f0ac98c41586d322d2465e 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -347,6 +347,55 @@ TEST(Layer, CosSimVecMatLayer) { } } +void testDepthwiseConvLayer(const string& type, bool useGpu) { + TestConfig config; + config.biasSize = 32; + config.layerConfig.set_type(type); + config.layerConfig.set_num_filters(32); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + config.inputDefs.push_back({INPUT_DATA, "layer_0", 2048, 192}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + conv->set_filter_size(2); + conv->set_filter_size_y(3); + conv->set_channels(16); + conv->set_padding(0); + conv->set_padding_y(1); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_groups(16); + conv->set_filter_channels(conv->channels() / conv->groups()); + conv->set_img_size(16); + conv->set_img_size_y(8); + conv->set_output_x(outputSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + /* caffeMode */ true)); + conv->set_output_y(outputSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + /* caffeMode */ true)); + config.layerConfig.set_size(conv->output_x() * conv->output_y() * + config.layerConfig.num_filters()); + + testLayerGrad(config, "depthwise_conv", 100, false, useGpu); + // Use small batch_size and useWeight=true to test biasGrad + testLayerGrad(config, "depthwise_conv", 2, false, useGpu, true, 0.02); +} + +TEST(Layer, depthwiseConvLayer) { + // 'depthwise_conv' is a sepecial case of 'exconv' whose + // groups size equals to the input channels size. + testDepthwiseConvLayer("exconv", /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + testDepthwiseConvLayer("exconv", /* useGpu= */ true); +#endif +} + void testConvLayer(const string& type, bool trans, bool useGpu) { TestConfig config; config.biasSize = 16; diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index ab81e67579e39a34e3ace18d14434eb86b66fa5b..fc112f1327f5ad5f1bdd04873394b1fa0e761e29 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3219,6 +3219,10 @@ def ParameterHook(type, **kwargs): if sparsity_ratio is not None: hook.sparsity_ratio = sparsity_ratio return hook + elif type == 'dpruning': + hook = ParameterUpdaterHookConfig() + hook.type = type + return hook else: return None