DepthwiseConvOp.h 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include "TensorType.h"
18 19 20

namespace paddle {

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
/**
 *\brief   Depthwise convolution forward. The outputData
 *         of depthwise convolution is same with ExpandConvLayer
 *         when groups equals inputChannels in ExpandConvLayer.
 *
 * \param[in]   inputData         input data.
 * \param[in]   filterData        the Paramters of the depthwise conv layer..
 * \param[in]   batchSize         batch size of input data.
 * \param[in]   outputChannels    channels of outputData.
 * \param[in]   outputHeight      height of outputData.
 * \param[in]   outputWidth       width of outputData.
 * \param[in]   inputHeight       height of inputData.
 * \param[in]   inputWidth        width of inputData..
 * \param[in]   filterHeight      height of filter.
 * \param[in]   filterWidth       widht of filter.
 * \param[in]   strideH           stride size in height direction.
 * \param[in]   strideW           stride size in width direction.
 * \param[in]   paddingH          padding size in height direction.
 * \param[in]   paddingW          padding size in width direction.
 * \param[out]  outputData        outputData.
 *
 */
43 44 45
template <DeviceType Device, class T>
class DepthwiseConvFunctor {
public:
46
  void operator()(const T* inputData,
47 48 49 50 51
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
52 53
                  int inputHeight,
                  int intputWidth,
54 55 56 57 58 59 60 61 62
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* outputData);
};

63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
/**
 *\brief  Functor tot compute the depthwise convolution backprop w.r.t input.
 *
 *
 * \param[in]   outputGradData    the grad data of output.
 * \param[in]   filterData        the Paramters of the depthwise conv layer..
 * \param[in]   batchSize         batch size of input data.
 * \param[in]   outputChannels    channels of outputData.
 * \param[in]   outputHeight      height of outputData.
 * \param[in]   outputWidth       width of outputData.
 * \param[in]   inputChannels     channels of input data.
 * \param[in]   inputHeight       height of inputData.
 * \param[in]   inputWidth        width of inputData..
 * \param[in]   filterHeight      height of filter.
 * \param[in]   filterWidth       widht of filter.
 * \param[in]   strideH           stride size in height direction.
 * \param[in]   strideW           stride size in width direction.
 * \param[in]   paddingH          padding size in height direction.
 * \param[in]   paddingW          padding size in width direction.
 * \param[out]  inputGrad         the grad data of input.
 *
 */
85 86 87
template <DeviceType Device, class T>
class DepthwiseConvGradInputFunctor {
public:
88
  void operator()(const T* outputGrad,
89 90 91 92 93
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
94
                  int inputChannels,
95 96 97 98 99 100 101 102 103 104 105
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* inputGrad);
};

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
/**
 *\brief  Functor tot compute the depthwise convolution backprop w.r.t filter.
 *
 * \param[in]   outputGradData    the grad data of output.
 * \param[in]   inputData         inputData.
 * \param[in]   batchSize         batch size of input data.
 * \param[in]   outputChannels    channels of outputData.
 * \param[in]   outputHeight      height of outputData.
 * \param[in]   outputWidth       width of outputData.
 * \param[in]   inputChannels     channels of input data.
 * \param[in]   inputHeight       height of inputData.
 * \param[in]   inputWidth        width of inputData..
 * \param[in]   filterHeight      height of filter.
 * \param[in]   filterWidth       widht of filter.
 * \param[in]   strideH           stride size in height direction.
 * \param[in]   strideW           stride size in width direction.
 * \param[in]   paddingH          padding size in height direction.
 * \param[in]   paddingW          padding size in width direction.
 * \param[in]   colData           Auxiliary data when calculating filterGrad.
125 126
 * \param[in]   multiplierData    Auxiliary data when calculating filterGrad.
 * \param[out]  filterGrad        the grad data of filter.
127 128
 *
 */
129 130 131
template <DeviceType Device, class T>
class DepthwiseConvGradFilterFunctor {
public:
132
  void operator()(const T* outputGrad,
133 134 135 136 137
                  const T* inputData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
138
                  int inputChannels,
139 140 141 142 143 144 145 146 147 148
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* colData,
                  T* filterGrad);
149
};
150 151

}  // namespace paddle