/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include "TensorType.h"
18 19 20

namespace paddle {

21 22 23 24 25 26 27 28 29 30 31
/**
 *\brief   Depthwise convolution forward. The outputData
 *         of depthwise convolution is same with ExpandConvLayer
 *         when groups equals inputChannels in ExpandConvLayer.
 *
 * \param[in]   inputData         input data.
 * \param[in]   filterData        the Paramters of the depthwise conv layer..
 * \param[in]   batchSize         batch size of input data.
 * \param[in]   outputChannels    channels of outputData.
 * \param[in]   outputHeight      height of outputData.
 * \param[in]   outputWidth       width of outputData.
32
 * \param[in]   inputChannels     channels of inputData.
33 34
 * \param[in]   inputHeight       height of inputData.
 * \param[in]   inputWidth        width of inputData..
35
 * \param[in]   filterMultiplier  equals to outputChannels/groups_.
36 37 38 39 40 41 42 43 44
 * \param[in]   filterHeight      height of filter.
 * \param[in]   filterWidth       widht of filter.
 * \param[in]   strideH           stride size in height direction.
 * \param[in]   strideW           stride size in width direction.
 * \param[in]   paddingH          padding size in height direction.
 * \param[in]   paddingW          padding size in width direction.
 * \param[out]  outputData        outputData.
 *
 */
45 46 47
template <DeviceType Device, class T>
class DepthwiseConvFunctor {
public:
48
  void operator()(const T* inputData,
49 50 51 52 53
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
54
                  int inputChannels,
55
                  int inputHeight,
56
                  int inputWidth,
57
                  int filterMultiplier,
58 59 60 61 62 63 64 65 66
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* outputData);
};

67 68 69 70 71 72 73 74 75 76 77 78
/**
 *\brief  Functor tot compute the depthwise convolution backprop w.r.t input.
 *
 *
 * \param[in]   outputGradData    the grad data of output.
 * \param[in]   filterData        the Paramters of the depthwise conv layer..
 * \param[in]   batchSize         batch size of input data.
 * \param[in]   outputChannels    channels of outputData.
 * \param[in]   outputHeight      height of outputData.
 * \param[in]   outputWidth       width of outputData.
 * \param[in]   inputChannels     channels of input data.
 * \param[in]   inputHeight       height of inputData.
79 80
 * \param[in]   inputWidth        width of inputData.
 * \param[in]   filterMultiplier  equals to outputChannels/groups_.
81 82 83 84 85 86 87 88 89
 * \param[in]   filterHeight      height of filter.
 * \param[in]   filterWidth       widht of filter.
 * \param[in]   strideH           stride size in height direction.
 * \param[in]   strideW           stride size in width direction.
 * \param[in]   paddingH          padding size in height direction.
 * \param[in]   paddingW          padding size in width direction.
 * \param[out]  inputGrad         the grad data of input.
 *
 */
90 91 92
template <DeviceType Device, class T>
class DepthwiseConvGradInputFunctor {
public:
93
  void operator()(const T* outputGrad,
94 95 96 97 98
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
99
                  int inputChannels,
100 101
                  int inputHeight,
                  int inputWidth,
102
                  int filterMultiplier,
103 104 105 106 107 108 109 110 111
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* inputGrad);
};

112 113 114 115 116 117 118 119 120 121 122
/**
 *\brief  Functor tot compute the depthwise convolution backprop w.r.t filter.
 *
 * \param[in]   outputGradData    the grad data of output.
 * \param[in]   inputData         inputData.
 * \param[in]   batchSize         batch size of input data.
 * \param[in]   outputChannels    channels of outputData.
 * \param[in]   outputHeight      height of outputData.
 * \param[in]   outputWidth       width of outputData.
 * \param[in]   inputChannels     channels of input data.
 * \param[in]   inputHeight       height of inputData.
123 124
 * \param[in]   inputWidth        width of inputData.
 * \param[in]   filterMultiplier  equals to outputChannels/groups_.
125 126 127 128 129 130 131
 * \param[in]   filterHeight      height of filter.
 * \param[in]   filterWidth       widht of filter.
 * \param[in]   strideH           stride size in height direction.
 * \param[in]   strideW           stride size in width direction.
 * \param[in]   paddingH          padding size in height direction.
 * \param[in]   paddingW          padding size in width direction.
 * \param[in]   colData           Auxiliary data when calculating filterGrad.
132 133
 * \param[in]   multiplierData    Auxiliary data when calculating filterGrad.
 * \param[out]  filterGrad        the grad data of filter.
134 135
 *
 */
136 137 138
template <DeviceType Device, class T>
class DepthwiseConvGradFilterFunctor {
public:
139
  void operator()(const T* outputGrad,
140 141 142 143 144
                  const T* inputData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
145
                  int inputChannels,
146 147
                  int inputHeight,
                  int inputWidth,
148
                  int filterMultiplier,
149 150 151 152 153 154 155 156
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* colData,
                  T* filterGrad);
157
};
158 159

}  // namespace paddle