diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ad332d2931bab144591f10190ec1c07cb6d7940c --- /dev/null +++ b/paddle/function/DepthwiseConvOp.cpp @@ -0,0 +1,308 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DepthwiseConvOp.h" +#include "GemmFunctor.h" +#include "paddle/math/MemoryHandle.h" + +namespace paddle { + +/* + * imData = [input_channels, input_height, input_width] + * colData = [input_channels, filter_height, filter_width, + * output_height, output_width] + */ +template +class DepthwiseConvFunctor { +public: + void operator()(int outputSize, + const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData) { + // NO_IMPLEMENTATION + } +}; + +template +class DepthwiseConvGradInputFunctor { +public: + void operator()(int inputSize, + const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad) {} +}; + +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(int num_i, + int colDataSize, + const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* multiplierData, + T* filterGrad) {} +}; + +/* + * \brief Forward calculation of convolution. + */ +template +class DepthwiseConvFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + size_t batchSize = input[0]; + // size_t inputChannels = input[1]; + // size_t inputHeight = input[2]; + // size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + size_t outputSize = batchSize * outputChannels * outputHeight * outputWidth; + + DepthwiseConvFunctor depthwiseConv; + depthwiseConv(outputSize, + inputData, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputData); + } +}; + +/* + * \brief Backward input calculation of convolution. + */ +template +class DepthwiseConvGradInputFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + // Since the implementation of Col2ImFunctor is ADD_TO, + // this function only supports ADD_TO mode. + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* outputGrad = inputs[0].data(); + real* filterData = inputs[1].data(); + real* inputGrad = outputs[0].data(); + + size_t inputSize = batchSize * inputChannels * inputHeight * inputWidth; + + DepthwiseConvGradInputFunctor depthwiseConvGradInput; + depthwiseConvGradInput(inputSize, + outputGrad, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + inputGrad); + } +}; + +/* + * \brief Backward filter calculation of convolution. + */ +template +class DepthwiseConvGradFilterFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + // const TensorShape& multiplier = inputs[2].shape(); + const TensorShape& filter = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* outputGrad = inputs[0].data(); + real* inputData = inputs[1].data(); + real* multiplierData = inputs[2].data(); + real* filterGrad = outputs[0].data(); + + size_t size = + inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; + + resizeBuffer(size); + real* colData = reinterpret_cast(memory_->getBuf()); + + DepthwiseConvGradFilterFunctor depthwiseConvGradFilter; + + for (size_t i = 0; i < batchSize; i++) { + depthwiseConvGradFilter(i, + size, + outputGrad, + inputData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + colData, + multiplierData, + filterGrad); + } + } +}; + +REGISTER_TYPED_FUNC(DepthwiseConv, CPU, DepthwiseConvFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradInput, + CPU, + DepthwiseConvGradInputFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, + CPU, + DepthwiseConvGradFilterFunction); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradInput, + GPU, + DepthwiseConvGradInputFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, + GPU, + DepthwiseConvGradFilterFunction); +#endif + +} // namespace paddle diff --git a/paddle/function/DepthwiseConvOp.h b/paddle/function/DepthwiseConvOp.h new file mode 100644 index 0000000000000000000000000000000000000000..8af1db974de520491eb75762a75f7c63f93c73f1 --- /dev/null +++ b/paddle/function/DepthwiseConvOp.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "ConvOp.h" + +namespace paddle { + +/* + * imData = [input_channels, input_height, input_width] + * colData = [input_channels, filter_height, filter_width, + * output_height, output_width] + */ +template +class DepthwiseConvFunctor { +public: + void operator()(int outputSize, + const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData); +}; + +template +class DepthwiseConvGradInputFunctor { +public: + void operator()(int inputSize, + const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad); +}; + +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(int num_i, + int colDataSize, + const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* multiplierData, + T* filterGrad); + +}; // namespace paddle + +} // namespace paddle diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..1b2d5d99ed2c557a1a41f916aa6e546a4bad8a43 --- /dev/null +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -0,0 +1,295 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ConvOp.h" +#include "DepthwiseConvOp.h" + +namespace paddle { +template +__global__ void ConvolutionDepthwiseWeightForward(const int nthreads, + const T* const bottom_data, const T* const weight_data, + const int num, const int channels, const int top_height, + const int top_width, const int bottom_height, const int bottom_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, T* const top_data) { + + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + if(index < nthreads) { + const int n = index / channels / top_height / top_width; + const int c = (index / top_height / top_width) % channels; + const int h = (index / top_width) % top_height; + const int w = index % top_width; + const T* weight = weight_data + c * kernel_h * kernel_w; + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_in = -pad_h + h * stride_h + kh * dilation_h; + const int w_in = -pad_w + w * stride_w + kw * dilation_w; + if ((h_in >= 0) && (h_in < bottom_height) + && (w_in >= 0) && (w_in < bottom_width)) { + const int offset = ((n * channels + c) * bottom_height + h_in) + * bottom_width + w_in; + value += (*weight) * bottom_data[offset]; + } + ++weight; + } + } + top_data[index] = value; + } +} + +template +__global__ void ConvolutionDepthwiseBottomBackward(const int nthreads, + const T* const top_diff, const T* const weight_data, + const int num, const int channels, const int top_height, + const int top_width, const int bottom_height, const int bottom_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, T* const bottom_diff) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if(index < nthreads) { + const int n = index / channels / bottom_height / bottom_width; + const int c = (index / bottom_height / bottom_width) % channels; + const int h = (index / bottom_width) % bottom_height; + const int w = index % bottom_width; + const T* weight = weight_data + c * kernel_h * kernel_w; + T value = 0; + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + const int h_out_s = h + pad_h - kh * dilation_h; + const int w_out_s = w + pad_w - kw * dilation_w; + if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { + const int h_out = h_out_s / stride_h; + const int w_out = w_out_s / stride_w; + //it affect the effectives + if ((h_out >= 0) && (h_out < top_height) + && (w_out >= 0) && (w_out < top_width)) { + const int offset = ((n * channels + c) * top_height + h_out) + * top_width + w_out; + value += (*weight) * top_diff[offset]; + } + } + ++weight; + } + } + bottom_diff[index] += value; + } +} + +template +__global__ void ConvolutionDepthwiseWeightBackward(const int num_i, const int nthreads, + const T* const top_diff, const T* const bottom_data, + const int num, const int channels, const int top_height, + const int top_width, const int bottom_height, const int bottom_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, T* const buffer_data) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int h = (index / top_width) % top_height; + const int w = index % top_width; + const int kh = (index / kernel_w / top_height / top_width) + % kernel_h; + const int kw = (index / top_height / top_width) % kernel_w; + const int h_in = -pad_h + h * stride_h + kh * dilation_h; + const int w_in = -pad_w + w * stride_w + kw * dilation_w; + if ((h_in >= 0) && (h_in < bottom_height) + && (w_in >= 0) && (w_in < bottom_width)) { + const int c = index / kernel_h / kernel_w / top_height / top_width; + const int n = num_i; + const int top_offset = ((n * channels + c) * top_height + h) + * top_width + w; + const int bottom_offset = ((n * channels + c) * bottom_height + h_in) + * bottom_width + w_in; + buffer_data[index] = top_diff[top_offset] * bottom_data[bottom_offset]; + } else { + buffer_data[index] = 0; + } + } +} + +template +class DepthwiseConvFunctor{ +public: + void operator()(int outputSize, + const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData){ + + size_t blocks = (outputSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + ConvolutionDepthwiseWeightForward + <<< grid, threads, 0, STREAM_DEFAULT >>>( + outputSize, + inputData, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + outputData); + } +}; + +template +class DepthwiseConvGradInputFunctor{ +public: + void operator()(int inputSize, + const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad){ + + size_t blocks = (inputSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + ConvolutionDepthwiseBottomBackward + // NOLINT_NEXT_LINE(whitespace/operators) + <<< grid, threads, 0, STREAM_DEFAULT >>>( + inputSize, + outputGrad, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + inputGrad); + } +}; + +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(int num_i, + int colDataSize, + const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* multiplierData, + T* filterGrad){ + + size_t blocks = (colDataSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + ConvolutionDepthwiseWeightBackward + <<< grid, threads, 0, STREAM_DEFAULT >>>( + i, + size, + outputGrad, + inputData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputHeight, + inputWidth, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + colData + ); + GemmFunctor gemm; + int M = size / outputHeight / outputWidth; + int N = 1; + int K = outputHeight * outputWidth; + gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1.0f, + colData, + K, + multiplierData, + N, + 1.0f, + filterGrad, + N); + //gemv + } +}; + +template class DepthwiseConvGradInputFunctor; +template class DepthwiseConvGradInputFunctor; +template class DepthwiseConvFunctor; +template class DepthwiseConvFunctor; +template class DepthwiseConvGradFilterFunctor; +template class DepthwiseConvGradFilterFunctor; + +} // namespace paddle diff --git a/paddle/gserver/layers/DepthwiseConvLayer.cpp b/paddle/gserver/layers/DepthwiseConvLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9df8a9df7cc4e0ec76eb71efb7f9a575510e6b6f --- /dev/null +++ b/paddle/gserver/layers/DepthwiseConvLayer.cpp @@ -0,0 +1,165 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DepthwiseConvLayer.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +/* + * The calculation of the exconvt(convolution transpose (deconv) operation) + * is a swap of forward and backward of the calculation of exconv. + * */ +REGISTER_LAYER(depthwise_conv, DepthwiseConvLayer); + +bool DepthwiseConvLayer::init(const LayerMap &layerMap, + const ParameterMap ¶meterMap) { + /* Initialize the basic convolutional parent class */ + ExpandConvBaseLayer::init(layerMap, parameterMap); + + size_t numInputs = config_.inputs_size(); + inputShape_.resize(numInputs); + filterShape_.resize(numInputs); + outputShape_.resize(numInputs); + multiplierShape_.resize(numInputs); + weightMultiplier_.resize(numInputs); + + for (int i = 0; i < config_.inputs_size(); i++) { + std::vector paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; + std::vector strides = {(size_t)strideY_[i], (size_t)stride_[i]}; + Matrix::resizeOrCreate(weightMultiplier_[i], + (size_t)outputH_[i] * (size_t)outputW_[i], + (size_t)1, + false, + useGpu_); + weightMultiplier_[i]->one(); + createFunction(forward_, + "DepthwiseConv", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i])); + + createFunction(backward_, + "DepthwiseConvGradInput", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i])); + + createFunction(backward_, + "DepthwiseConvGradFilter", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i])); + } + return true; +} + +// i is the index of input layers +#define BACKWARD_INPUT(i, inputs, outputs) \ + backward_[2 * i]->calc(inputs, outputs) +#define BACKWARD_FILTER(i, inputs, outputs) \ + backward_[2 * i + 1]->calc(inputs, outputs) + +void DepthwiseConvLayer::forward(PassType passType) { + Layer::forward(passType); + + size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + resetOutput(batchSize, getOutputSize()); + + // Calculate the shape of the input, output, and filter. + for (size_t i = 0; i < inputLayers_.size(); ++i) { + inputShape_[i] = TensorShape({(size_t)batchSize, + (size_t)channels_[i], + (size_t)imgSizeH_[i], + (size_t)imgSizeW_[i]}); + multiplierShape_[i] = + TensorShape({(size_t)outputH_[i] * (size_t)outputW_[i], (size_t)1}); + filterShape_[i] = TensorShape({(size_t)groups_[i], + (size_t)numFilters_ / groups_[i], + (size_t)channels_[i] / groups_[i], + (size_t)filterSizeY_[i], + (size_t)filterSize_[i]}); + outputShape_[i] = TensorShape({(size_t)batchSize, + (size_t)numFilters_, + (size_t)outputH_[i], + (size_t)outputW_[i]}); + } + + // Calculate the output value. + for (size_t i = 0; i < inputLayers_.size(); ++i) { + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(i), inputShape_[i]); + inputs.addArg(*weights_[i]->getW(), filterShape_[i]); + outputs.addArg( + *getOutputValue(), outputShape_[i], i == 0 ? ASSIGN_TO : ADD_TO); + + forward_[i]->calc(inputs, outputs); + } + + /* add the bias-vector */ + if (biases_.get()) { + if (sharedBiases_) { + addSharedBias(); + } else { + addUnsharedBias(); + } + } + + /* activation */ + forwardActivation(); +} + +void DepthwiseConvLayer::backward(const UpdateCallback &callback) { + backwardActivation(); + + MatrixPtr outGrad = getOutputGrad(); + if (biases_ && biases_->getWGrad()) { + bpropBiases(outGrad); + /* Increasing the number of gradient */ + biases_->getParameterPtr()->incUpdate(callback); + } + + // Calculate the input grad and filter grad. + for (size_t i = 0; i < inputLayers_.size(); ++i) { + if (getInputGrad(i)) { + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outputShape_[i]); + inputs.addArg(*weights_[i]->getW(), filterShape_[i]); + outputs.addArg(*getInputGrad(i), inputShape_[i], ADD_TO); + BACKWARD_INPUT(i, inputs, outputs); + } + + if (weights_[i]->getWGrad()) { + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outputShape_[i]); + inputs.addArg(*getInputValue(i), inputShape_[i]); + inputs.addArg(*weightMultiplier_[i], multiplierShape_[i]); + // weight_multiplier + outputs.addArg(*weights_[i]->getWGrad(), filterShape_[i], ADD_TO); + BACKWARD_FILTER(i, inputs, outputs); + + /* Increasing the number of gradient */ + weights_[i]->getParameterPtr()->incUpdate(callback); + } + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/DepthwiseConvLayer.h b/paddle/gserver/layers/DepthwiseConvLayer.h new file mode 100644 index 0000000000000000000000000000000000000000..61dd87c12a0b9f9090d26ba3bb9a0d836786a89e --- /dev/null +++ b/paddle/gserver/layers/DepthwiseConvLayer.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "ExpandConvBaseLayer.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +/** + * @brief A subclass of convolution layer. + * This layer expands input and use matrix multiplication to + * calculate convolution operation. + * + * The config file api is img_conv_layer. + */ + +class DepthwiseConvLayer : public ExpandConvBaseLayer { +public: + explicit DepthwiseConvLayer(const LayerConfig& config) + : ExpandConvBaseLayer(config) {} + + ~DepthwiseConvLayer() {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback) override; + +protected: + std::vector inputShape_; + std::vector filterShape_; + std::vector outputShape_; + std::vector multiplierShape_; + std::vector weightMultiplier_; +}; + +} // namespace paddle