DepthwiseConvOpGpu.cu 12.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
16
#include "DepthwiseConvOp.h"
17
#include "GemmFunctor.h"
18
#include "paddle/math/BaseMatrix.h"
19 20

namespace paddle {
21

22
// Depthwise convolution, forward pass.
//
// Each thread produces one output element. The flattened thread index is
// built from a 2-D grid of 1-D blocks and decoded as
// (batch, c_out, h_out, w_out) with w_out varying fastest; threads whose
// index is not below nthreads do nothing. Output channel c_out reads input
// channel c_out / filterMultiplier and the filter plane that starts at
// filterData + c_out * filterHeight * filterWidth. Filter taps that fall
// outside the input (the padding region) contribute nothing to the sum.
template <class T>
__global__
void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const outputData) {
  const int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Decode the flattened output coordinate.
  const int w_out = index % outputWidth;
  const int h_out = (index / outputWidth) % outputHeight;
  const int c_out = (index / outputHeight / outputWidth) % outputChannels;
  const int batch = index / outputChannels / outputHeight / outputWidth;

  const int c_in = c_out / filterMultiplier;
  const T* const taps = filterData + c_out * filterHeight * filterWidth;

  // Top-left input coordinate covered by the filter window (may be negative
  // when the window overlaps the padding).
  const int h_origin = h_out * strideH - paddingH;
  const int w_origin = w_out * strideW - paddingW;
  const bool window_inside =
      (h_origin >= 0) && (h_origin + filterHeight - 1 < inputHeight) &&
      (w_origin >= 0) && (w_origin + filterWidth - 1 < inputWidth);

  T acc = 0;
  if (window_inside) {
    // The whole window lies inside the input: no per-tap bounds checks.
    for (int kh = 0; kh < filterHeight; ++kh) {
      const int row =
          (batch * inputChannels + c_in) * inputHeight + (h_origin + kh);
      for (int kw = 0; kw < filterWidth; ++kw) {
        const int offset = row * inputWidth + (w_origin + kw);
        acc += taps[kh * filterWidth + kw] * inputData[offset];
      }
    }
  } else {
    // The window overlaps the border: skip taps that land outside the input.
    for (int kh = 0; kh < filterHeight; ++kh) {
      const int h_in = h_origin + kh;
      const bool row_inside = (h_in >= 0) && (h_in < inputHeight);
      for (int kw = 0; kw < filterWidth; ++kw) {
        const int w_in = w_origin + kw;
        if (row_inside && (w_in >= 0) && (w_in < inputWidth)) {
          const int offset = ((batch * inputChannels + c_in)
              * inputHeight + h_in) * inputWidth + w_in;
          acc += taps[kh * filterWidth + kw] * inputData[offset];
        }
      }
    }
  }
  outputData[index] = acc;
}

80
// Depthwise convolution, gradient with respect to the input.
//
// Each thread handles one input element. The flattened thread index is
// decoded as (batch, c_in, h_in, w_in) with w_in varying fastest; threads
// whose index is not below nthreads do nothing. For every output channel in
// [c_in * filterMultiplier, (c_in + 1) * filterMultiplier) the thread sums
// top_diff * weight over the output positions in the clamped ranges computed
// below. The sum is ADDED to bottom_diff[index] (the existing value is kept).
template <class T>
__global__
void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const bottom_diff) {
  const int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Decode the flattened input coordinate.
  const int w_in = index % inputWidth;
  const int h_in = (index / inputWidth) % inputHeight;
  const int c_in = (index / inputHeight / inputWidth) % inputChannels;
  const int batch = index / inputChannels / inputHeight / inputWidth;

  // Output rows visited for this input row, clamped to [0, outputHeight - 1].
  int h_first = (h_in - filterHeight + paddingH + strideH) / strideH;
  if (h_first < 0) {
    h_first = 0;
  }
  int h_last = (h_in + paddingH) / strideH;
  if (h_last > outputHeight - 1) {
    h_last = outputHeight - 1;
  }

  // Output columns visited for this input column, clamped likewise.
  int w_first = (w_in - filterWidth + paddingW + strideW) / strideW;
  if (w_first < 0) {
    w_first = 0;
  }
  int w_last = (w_in + paddingW) / strideW;
  if (w_last > outputWidth - 1) {
    w_last = outputWidth - 1;
  }

  const int c_first = c_in * filterMultiplier;
  T grad = 0;
  for (int m = 0; m < filterMultiplier; ++m) {
    const int c_out = c_first + m;
    for (int h_out = h_first; h_out <= h_last; ++h_out) {
      // Filter row that maps output row h_out onto input row h_in.
      const int filter_h = h_in + paddingH - h_out * strideH;
      for (int w_out = w_first; w_out <= w_last; ++w_out) {
        const int filter_w = w_in + paddingW - w_out * strideW;
        const int filter_offset = c_out * filterHeight * filterWidth
            + filter_h * filterWidth + filter_w;
        const int top_diff_offset = ((batch * outputChannels + c_out)
            * outputHeight + h_out) * outputWidth + w_out;
        grad += top_diff[top_diff_offset] * weight_data[filter_offset];
      }
    }
  }
  bottom_diff[index] += grad;
}

129
// Depthwise convolution, per-sample partial products for the filter gradient.
//
// Each thread writes one element of buffer_data. The flattened thread index
// is decoded as (c_out, kh, kw, h_out, w_out) with w_out varying fastest;
// threads whose index is not below nthreads do nothing. For batch sample
// num_i the thread stores top_diff(c_out, h_out, w_out) times the input
// element that filter tap (kh, kw) reads for that output position, or 0 when
// that input position lies outside the input. Every in-range element of
// buffer_data is overwritten (no accumulation happens in this kernel).
template <class T>
__global__
void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
    const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const buffer_data) {
  const int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Decode output position and filter tap from the flattened index.
  const int w_out = index % outputWidth;
  const int h_out = (index / outputWidth) % outputHeight;
  const int kw = (index / outputHeight / outputWidth) % filterWidth;
  const int kh = (index / filterWidth / outputHeight / outputWidth)
      % filterHeight;

  // Input element read by tap (kh, kw) at output position (h_out, w_out).
  const int h_in = h_out * strideH - paddingH + kh;
  const int w_in = w_out * strideW - paddingW + kw;

  const bool outside = (h_in < 0) || (h_in >= inputHeight)
      || (w_in < 0) || (w_in >= inputWidth);
  if (outside) {
    // The tap lands in the padding region: contributes nothing.
    buffer_data[index] = 0;
    return;
  }

  const int c_out = index /
      (filterHeight * filterWidth * outputHeight * outputWidth);
  const int c_in = c_out / filterMultiplier;
  const int top_offset = ((num_i * outputChannels + c_out)
      * outputHeight + h_out) * outputWidth + w_out;
  const int bottom_offset = ((num_i * inputChannels + c_in)
      * inputHeight + h_in) * inputWidth + w_in;
  buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
}

// GPU specialization of the depthwise convolution forward functor.
//
// Launches ConvolutionDepthwiseForward on STREAM_DEFAULT with at least one
// thread per output element (batchSize * outputChannels * outputHeight *
// outputWidth). The launch is not followed by any synchronization or error
// check here. All pointers are passed straight to the kernel and are
// therefore expected to be device-accessible.
template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{
public:
  void operator()(const T* inputData,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputChannels,
            int inputHeight,
            int inputWidth,
            int filterMultiplier,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* outputData){
    const int outputSize =
        batchSize * outputChannels * outputHeight * outputWidth;

    // Nothing to compute. Returning here also avoids a launch with
    // gridDim.y == 0, which is an invalid execution configuration and would
    // leave a CUDA error pending for later calls.
    if (outputSize <= 0) {
      return;
    }

    // The kernel flattens a (blockX x blockY) grid of 1-D blocks, so any grid
    // with blockX * blockY >= blocks covers every output element.
    const size_t threadsPerBlock = 1024;
    const size_t blockX = 512;
    const size_t blocks =
        (static_cast<size_t>(outputSize) + threadsPerBlock - 1)
        / threadsPerBlock;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseForward<T>
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            outputSize,
            inputData,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterMultiplier,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            outputData);
    }
};

// GPU specialization of the depthwise convolution input-gradient functor.
//
// Launches ConvolutionDepthwiseInputBackward on STREAM_DEFAULT with at least
// one thread per input element (batchSize * inputChannels * inputHeight *
// inputWidth). The kernel adds its result to inputGrad rather than
// overwriting it, so inputGrad must hold valid values on entry. The launch is
// not followed by any synchronization or error check here.
template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{
public:
  void operator()(const T* outputGrad,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputChannels,
            int inputHeight,
            int inputWidth,
            int filterMultiplier,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* inputGrad){
    const int inputSize =
        batchSize * inputChannels * inputHeight * inputWidth;

    // Nothing to compute. Returning here also avoids a launch with
    // gridDim.y == 0, which is an invalid execution configuration and would
    // leave a CUDA error pending for later calls.
    if (inputSize <= 0) {
      return;
    }

    // The kernel flattens a (blockX x blockY) grid of 1-D blocks, so any grid
    // with blockX * blockY >= blocks covers every input element.
    const size_t threadsPerBlock = 1024;
    const size_t blockX = 512;
    const size_t blocks =
        (static_cast<size_t>(inputSize) + threadsPerBlock - 1)
        / threadsPerBlock;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseInputBackward<T>
          // NOLINT_NEXT_LINE(whitespace/operators)
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            inputSize,
            outputGrad,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterMultiplier,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            inputGrad);
    }
};

// GPU specialization of the depthwise convolution filter-gradient functor.
//
// For each batch sample i, ConvolutionDepthwiseFilterBackward fills colData
// (viewed as an M x K matrix, M = outputChannels * filterHeight *
// filterWidth, K = outputHeight * outputWidth) with per-position products,
// and sumRows then folds that matrix into the M x 1 view of filterGrad.
// colData is scratch space that is overwritten on every iteration and must
// hold at least M * K elements.
// NOTE(review): sumRows is called with both scale factors 1, which looks
// like "filterGrad += row sums of colData" - confirm against BaseMatrix.
// NOTE(review): reusing colData across iterations relies on the kernel and
// sumRows being ordered on the same stream - confirm.
template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  void operator()(const T* outputGrad,
                const T* inputData,
                int batchSize,
                int outputChannels,
                int outputHeight,
                int outputWidth,
                int inputChannels,
                int inputHeight,
                int inputWidth,
                int filterMultiplier,
                int filterHeight,
                int filterWidth,
                int strideH,
                int strideW,
                int paddingH,
                int paddingW,
                T* colData,
                T* filterGrad){
        const int colDataSize = outputChannels * filterHeight * filterWidth
            * outputHeight * outputWidth;

        // Nothing to accumulate. Returning here avoids a launch with
        // gridDim.y == 0 (invalid execution configuration) and, when the
        // output plane is empty, the integer division by zero that
        // colDataSize / (outputHeight * outputWidth) used to perform.
        if (colDataSize <= 0) {
            return;
        }

        // Loop-invariant matrix shape of colData. M is computed directly
        // instead of as colDataSize / K so that no division is needed.
        const int K = outputHeight * outputWidth;
        const int M = outputChannels * filterHeight * filterWidth;

        // The kernel flattens a (blockX x blockY) grid of 1-D blocks, so any
        // grid with blockX * blockY >= blocks covers every colData element.
        const size_t threadsPerBlock = 1024;
        const size_t blockX = 512;
        const size_t blocks =
            (static_cast<size_t>(colDataSize) + threadsPerBlock - 1)
            / threadsPerBlock;
        const size_t blockY = (blocks + blockX - 1) / blockX;
        dim3 threads(threadsPerBlock, 1);
        dim3 grid(blockX, blockY);

        // Non-owning M x 1 view over filterGrad.
        // NOTE(review): the trailing (false, true) flags are presumably
        // (trans, useGpu) - confirm against the BaseMatrix constructor.
        BaseMatrix filterGradMatrix(M, 1, filterGrad, false, true);

        for (int i = 0; i < batchSize; i++) {
            ConvolutionDepthwiseFilterBackward<T>
                <<< grid, threads, 0, STREAM_DEFAULT >>>(
                    i,
                    colDataSize,
                    outputGrad,
                    inputData,
                    batchSize,
                    outputChannels,
                    outputHeight,
                    outputWidth,
                    inputChannels,
                    inputHeight,
                    inputWidth,
                    filterMultiplier,
                    filterHeight,
                    filterWidth,
                    strideH,
                    strideW,
                    paddingH,
                    paddingW,
                    colData);

            BaseMatrix colMatrix(M, K, colData, false, true);
            filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0);
        }
    }
};

333
// Explicitly instantiate the three GPU functors for the single floating-point
// type selected at build time: double when PADDLE_TYPE_DOUBLE is defined,
// float otherwise.
#ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif
342 343

}  // namespace paddle