/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "DepthwiseConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/BaseMatrix.h"

namespace paddle {

21
// CUDA kernel to compute the depthwise convolution forward pass.
// One thread computes one output element; nthreads is expected to be
// batchSize * outputChannels * outputHeight * outputWidth. The flat thread
// index is derived from a 2-D grid of 1-D blocks (see the host functors in
// this file for the matching launch configuration).
template <class T>
__global__
void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const outputData) {
  int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  if (index < nthreads) {
    // Decompose the flat index into (batch, c_out, h_out, w_out),
    // NCHW layout with w fastest.
    const int batch = index / outputChannels / outputHeight / outputWidth;
    const int c_out = (index / outputHeight / outputWidth) % outputChannels;
    const int h_out = (index / outputWidth) % outputHeight;
    const int w_out = index % outputWidth;

    // Depthwise mapping: each output channel reads exactly one input
    // channel, and each input channel feeds filterMultiplier output channels.
    const int c_in = c_out / filterMultiplier;
    // One filterHeight x filterWidth filter per output channel.
    const T* weight = filterData + c_out * filterHeight * filterWidth;
    T value = 0;
    const int h_in_start = -paddingH + h_out * strideH;
    const int w_in_start = -paddingW + w_out * strideW;
    const int h_in_end = -paddingH + h_out * strideH + filterHeight - 1;
    const int w_in_end = -paddingW + w_out * strideW + filterWidth - 1;
    if ((h_in_start >= 0) && (h_in_end < inputHeight)
       && (w_in_start >= 0) && (w_in_end < inputWidth)) {
        // Fast path: the whole filter window lies inside the input,
        // so no per-tap boundary check is needed.
        for (int kh = 0; kh < filterHeight; ++kh) {
            for (int kw = 0; kw < filterWidth; ++kw) {
                const int h_in = -paddingH + h_out * strideH + kh;
                const int w_in = -paddingW + w_out * strideW + kw;
                const int offset = ((batch * inputChannels + c_in)
                    * inputHeight + h_in) * inputWidth + w_in;
                value += (*weight) * inputData[offset];
                ++weight;
            }
        }
    } else {
        // Slow path: part of the window falls on the zero padding;
        // taps outside the input are skipped (contribute zero), but the
        // weight pointer still advances every tap.
        for (int kh = 0; kh < filterHeight; ++kh) {
            for (int kw = 0; kw < filterWidth; ++kw) {
                const int h_in = -paddingH + h_out * strideH + kh;
                const int w_in = -paddingW + w_out * strideW + kw;
                if ((h_in >= 0) && (h_in < inputHeight)
                   && (w_in >= 0) && (w_in < inputWidth)) {
                    const int offset = ((batch * inputChannels + c_in)
                        * inputHeight + h_in) * inputWidth + w_in;
                    value += (*weight) * inputData[offset];
                }
                ++weight;
            }
       }
    }
    outputData[index] = value;
  }
}

79
// CUDA kernel to compute the depthwise convolution backprop w.r.t. input.
// One thread computes the gradient of one input element by gathering over
// all output positions whose filter window covers that element. nthreads is
// expected to be num * inputChannels * inputHeight * inputWidth.
template <class T>
__global__
void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const bottom_diff) {
  int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Decompose the flat index into (batch, c_in, h_in, w_in).
    const int batch = index / inputChannels / inputHeight / inputWidth;
    const int c_in = (index / inputHeight / inputWidth) % inputChannels;
    const int h_in = (index / inputWidth) % inputHeight;
    const int w_in = index % inputWidth;

    // Output channels [c_out_start, c_out_start + filterMultiplier) all
    // read from input channel c_in in the forward pass.
    const int c_out_start = c_in * filterMultiplier;

    // Range of output rows/cols whose receptive field contains (h_in, w_in),
    // clamped to the valid output extent.
    int h_out_start = (h_in - filterHeight + paddingH + strideH)/strideH;
    h_out_start = 0 > h_out_start ? 0 : h_out_start;
    int h_out_end = (h_in + paddingH)/strideH;
    h_out_end = outputHeight - 1 < h_out_end? outputHeight - 1 : h_out_end;
    int w_out_start = (w_in - filterWidth + paddingW + strideW)/strideW;
    w_out_start = 0 > w_out_start ? 0 : w_out_start;
    int w_out_end = (w_in + paddingW)/strideW;
    w_out_end = outputWidth - 1 < w_out_end? outputWidth - 1 : w_out_end;

    T value = 0;

    for (int c_out = c_out_start;
         c_out < c_out_start + filterMultiplier; c_out ++) {
        for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) {
            // Filter tap that connects this output position to (h_in, w_in).
            const int filter_h = h_in + paddingH - h_out * strideH;
            for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) {
                const int filter_w = w_in + paddingW - w_out * strideW;
                const int filter_offset = c_out * filterHeight * filterWidth
                    + filter_h * filterWidth + filter_w;
                const int top_diff_offset = ((batch * outputChannels + c_out) *
                    outputHeight + h_out)* outputWidth + w_out;
                value += top_diff[top_diff_offset] * weight_data[filter_offset];
            }
        }
    }
    // NOTE: accumulates into bottom_diff rather than overwriting it.
    bottom_diff[index] += value;
   }
}

128
// CUDA kernel to compute the depthwise convolution backprop w.r.t. filter.
// Processes a single image (batch index num_i). Each thread writes one
// partial product top_diff * input into buffer_data; nthreads is expected
// to be outputChannels * filterHeight * filterWidth * outputHeight *
// outputWidth. The host side then reduces buffer_data over the
// outputHeight * outputWidth positions to obtain the filter gradient.
template <class T>
__global__
void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
    const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const buffer_data) {
  int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Decompose the flat index into (c_out, kh, kw, h_out, w_out).
    const int h_out = (index / outputWidth) % outputHeight;
    const int w_out = index % outputWidth;
    const int kh = (index / filterWidth / outputHeight / outputWidth)
          % filterHeight;
    const int kw = (index / outputHeight / outputWidth) % filterWidth;
    // Input position read by filter tap (kh, kw) at output (h_out, w_out).
    const int h_in = -paddingH + h_out * strideH + kh;
    const int w_in = -paddingW + w_out * strideW + kw;
    if ((h_in >= 0) && (h_in < inputHeight)
          && (w_in >= 0) && (w_in < inputWidth)) {
      const int c_out = index /
            (filterHeight * filterWidth * outputHeight * outputWidth);
      const int c_in = c_out / filterMultiplier;
      const int batch = num_i;
      const int top_offset = ((batch * outputChannels + c_out) *
            outputHeight + h_out) * outputWidth + w_out;
      const int bottom_offset = ((batch * inputChannels + c_in)
            * inputHeight + h_in) * inputWidth + w_in;
      buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
    } else {
      // Tap falls in the padding area: zero the slot so the later
      // reduction is unaffected.
      buffer_data[index] = 0;
    }
  }
}

// GPU specialization of the depthwise convolution forward functor.
// Launches one thread per output element of the NCHW output tensor.
template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{
public:
  void operator()(const T* inputData,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputChannels,
            int inputHeight,
            int inputWidth,
            int filterMultiplier,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* outputData){
    int outputSize = batchSize * outputChannels * outputHeight * outputWidth;

    // 1024-thread blocks spread over a fixed-width (512 x blockY) grid;
    // the kernel's index formula
    // (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
    // together with the `index < nthreads` guard covers outputSize threads.
    size_t blocks = (outputSize + 1024 -1) / 1024;
    size_t blockX = 512;
    size_t blockY = (blocks+512-1)/512;
    dim3 threads(1024, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseForward<T>
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            outputSize,
            inputData,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterMultiplier,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            outputData);
    }
};

// GPU specialization of the depthwise convolution input-gradient functor.
// Launches one thread per input element; the kernel accumulates into
// inputGrad (+=), so the caller controls whether it is pre-zeroed.
template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{
public:
  void operator()(const T* outputGrad,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputChannels,
            int inputHeight,
            int inputWidth,
            int filterMultiplier,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* inputGrad){
    int inputSize = batchSize * inputChannels * inputHeight * inputWidth;

    // Same (512 x blockY) grid of 1024-thread blocks as the forward pass.
    size_t blocks = (inputSize + 1024 -1) / 1024;
    size_t blockX = 512;
    size_t blockY = (blocks+512-1)/512;
    dim3 threads(1024, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseInputBackward<T>
          // NOLINT_NEXT_LINE(whitespace/operators)
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            inputSize,
            outputGrad,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterMultiplier,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            inputGrad);
    }
};

// GPU specialization of the depthwise convolution filter-gradient functor.
// For each image: a kernel fills colData with per-position partial products
// (one per filter tap per output position), then the outputHeight*outputWidth
// positions are reduced with BaseMatrix::sumRows and accumulated (+1.0 scale)
// into filterGrad. colData must hold at least
// outputChannels * filterHeight * filterWidth * outputHeight * outputWidth
// elements.
template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  void operator()(const T* outputGrad,
                const T* inputData,
                int batchSize,
                int outputChannels,
                int outputHeight,
                int outputWidth,
                int inputChannels,
                int inputHeight,
                int inputWidth,
                int filterMultiplier,
                int filterHeight,
                int filterWidth,
                int strideH,
                int strideW,
                int paddingH,
                int paddingW,
                T* colData,
                T* filterGrad){
        int colDataSize = outputChannels * filterHeight * filterWidth
            * outputHeight * outputWidth;

        // Same (512 x blockY) grid of 1024-thread blocks as the forward pass.
        size_t blocks = (colDataSize + 1024 -1) / 1024;
        size_t blockX = 512;
        size_t blockY = (blocks+512-1)/512;
        dim3 threads(1024, 1);
        dim3 grid(blockX, blockY);
        // View filterGrad as a column vector; sumRows below accumulates
        // one row-sum of colMatrix per filter coefficient.
        BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth,
            1, filterGrad, false, true);

        for (int i = 0; i < batchSize; i++) {
            ConvolutionDepthwiseFilterBackward<T>
                <<< grid, threads, 0, STREAM_DEFAULT >>>(
                    i,
                    colDataSize,
                    outputGrad,
                    inputData,
                    batchSize,
                    outputChannels,
                    outputHeight,
                    outputWidth,
                    inputChannels,
                    inputHeight,
                    inputWidth,
                    filterMultiplier,
                    filterHeight,
                    filterWidth,
                    strideH,
                    strideW,
                    paddingH,
                    paddingW,
                    colData);
            // Reduce the outputHeight*outputWidth partials of each filter
            // coefficient and add them into filterGrad.
            int K = outputHeight * outputWidth;
            int M = colDataSize / K;

            BaseMatrix colMatrix(M, K, colData, false, true);
            filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0);
        }
    }
};

332
// Explicit instantiations for the real type selected at build time.
#ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif

}  // namespace paddle