// DepthwiseConvOpGpu.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "DepthwiseConvOp.h"
16
#include "GemmFunctor.h"
17
#include "paddle/math/BaseMatrix.h"
18 19

namespace paddle {
// CUDA kernel computing the depthwise convolution forward pass.
//
// Each thread produces one output element. Threads are numbered
//   index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// (a 2-D grid of 1-D blocks); threads with index >= nthreads do nothing.
//
// Tensor layouts implied by the offset arithmetic (NCHW):
//   inputData  : [batchSize, inputChannels, inputHeight, inputWidth]
//   filterData : [outputChannels, filterHeight, filterWidth]
//   outputData : [batchSize, outputChannels, outputHeight, outputWidth]
// Output channel c reads input channel c, so the kernel assumes
// inputChannels == outputChannels.
template <class T>
__global__
void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterHeight, const int filterWidth,
    const int strideH, const int strideW, const int paddingH,
    const int paddingW, T* const outputData) {
  int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  if (index < nthreads) {
    const int n = index / outputChannels / outputHeight / outputWidth;
    const int c = (index / outputHeight / outputWidth) % outputChannels;
    const int h = (index / outputWidth) % outputHeight;
    const int w = index % outputWidth;
    const T* weight = filterData + c * filterHeight * filterWidth;
    T value = 0;

    // Receptive field of output (h, w) in input coordinates. The start can
    // be negative, and the end can run past the input, because of padding.
    const int h_in_start = -paddingH + h * strideH;
    const int w_in_start = -paddingW + w * strideW;
    const int h_in_end = h_in_start + filterHeight - 1;
    const int w_in_end = w_in_start + filterWidth - 1;

    // Row index of the first row of the (n, c) plane of the INPUT tensor.
    // The batch stride must be the input's, i.e. inputChannels, in both
    // branches below.
    const int planeRow = (n * inputChannels + c) * inputHeight;

    if ((h_in_start >= 0) && (h_in_end < inputHeight) &&
        (w_in_start >= 0) && (w_in_end < inputWidth)) {
      // Fast path: the whole window lies inside the input, so no per-tap
      // bounds checks are needed.
      for (int kh = 0; kh < filterHeight; ++kh) {
        for (int kw = 0; kw < filterWidth; ++kw) {
          const int h_in = h_in_start + kh;
          const int w_in = w_in_start + kw;
          const int offset = (planeRow + h_in) * inputWidth + w_in;
          value += (*weight) * inputData[offset];
          ++weight;
        }
      }
    } else {
      // Border path: taps landing in the zero padding are skipped. The
      // weight pointer still advances so it stays aligned with (kh, kw).
      for (int kh = 0; kh < filterHeight; ++kh) {
        for (int kw = 0; kw < filterWidth; ++kw) {
          const int h_in = h_in_start + kh;
          const int w_in = w_in_start + kw;
          if ((h_in >= 0) && (h_in < inputHeight) &&
              (w_in >= 0) && (w_in < inputWidth)) {
            const int offset = (planeRow + h_in) * inputWidth + w_in;
            value += (*weight) * inputData[offset];
          }
          ++weight;
        }
      }
    }
    outputData[index] = value;
  }
}

// CUDA kernel computing the depthwise convolution gradient w.r.t. the input.
//
// Each thread handles one input element (n, c, h, w). Threads are numbered
//   index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// (a 2-D grid of 1-D blocks); threads with index >= nthreads do nothing.
// For every filter tap (kh, kw), the thread finds the output position (if
// any) whose window places that tap on (h, w). It then accumulates
// weight(c, kh, kw) * top_diff(n, c, h_out, w_out).
//
// Tensor layouts implied by the offset arithmetic (NCHW):
//   top_diff    : [num, outputChannels, outputHeight, outputWidth]
//   weight_data : [channels, filterHeight, filterWidth]
//   bottom_diff : [num, inputChannels, inputHeight, inputWidth]
// Channel c indexes both tensors, so inputChannels == outputChannels is
// assumed. The result is ADDED (+=) to bottom_diff, so bottom_diff must
// already hold valid values (e.g. zeros) on entry.
template <class T>
__global__
void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterHeight, const int filterWidth,
    const int strideH, const int strideW, const int paddingH,
    const int paddingW, T* const bottom_diff) {
  int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    const int n = index / inputChannels / inputHeight / inputWidth;
    const int c = (index / inputHeight / inputWidth) % inputChannels;
    const int h = (index / inputWidth) % inputHeight;
    const int w = index % inputWidth;
    const T* weight = weight_data + c * filterHeight * filterWidth;
    // Row index of the first row of the (n, c) plane of top_diff.
    const int topPlaneRow = (n * outputChannels + c) * outputHeight;
    T value = 0;
    for (int kh = 0; kh < filterHeight; ++kh) {
      for (int kw = 0; kw < filterWidth; ++kw) {
        // This is the output coordinate scaled by the stride; it is a real
        // output position only when it is divisible by the stride.
        // Negative values either fail the divisibility test (C++ '%' keeps
        // the dividend's sign) or give a negative h_out/w_out, which the
        // range check below rejects.
        const int h_out_s = h + paddingH - kh;
        const int w_out_s = w + paddingW - kw;
        if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) {
          const int h_out = h_out_s / strideH;
          const int w_out = w_out_s / strideW;
          // TODO(zhaolong): this 'if' hurts performance and needs to be
          // optimized.
          if ((h_out >= 0) && (h_out < outputHeight) &&
              (w_out >= 0) && (w_out < outputWidth)) {
            const int offset = (topPlaneRow + h_out) * outputWidth + w_out;
            value += (*weight) * top_diff[offset];
          }
        }
        ++weight;
      }
    }
    bottom_diff[index] += value;
  }
}

// CUDA kernel computing the per-position terms of the depthwise
// convolution filter gradient for a single batch sample, num_i.
//
// Threads are numbered
//   index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// (a 2-D grid of 1-D blocks); threads with index >= nthreads do nothing.
// Thread `index` writes one entry of buffer_data, laid out as
//   [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
// It stores top_diff(num_i, c, h, w) * inputData(num_i, c, h_in, w_in), or 0
// when tap (kh, kw) of output (h, w) falls in the padding. Summing each row
// of outputHeight * outputWidth entries gives this sample's contribution to
// the gradient of filter entry (c, kh, kw).
//
// Channel c indexes both top_diff and inputData, so inputChannels ==
// outputChannels is assumed. The `num` parameter is not used by this kernel.
template <class T>
__global__
void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
    const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterHeight, const int filterWidth,
    const int strideH, const int strideW, const int paddingH,
    const int paddingW, T* const buffer_data) {
  int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Decompose index into (c, kh, kw, h, w); w is the innermost dimension.
    const int h = (index / outputWidth) % outputHeight;
    const int w = index % outputWidth;
    const int kh = (index / filterWidth / outputHeight / outputWidth)
        % filterHeight;
    const int kw = (index / outputHeight / outputWidth) % filterWidth;
    const int h_in = -paddingH + h * strideH + kh;
    const int w_in = -paddingW + w * strideW + kw;
    if ((h_in >= 0) && (h_in < inputHeight) &&
        (w_in >= 0) && (w_in < inputWidth)) {
      const int c =
          index / filterHeight / filterWidth / outputHeight / outputWidth;
      const int n = num_i;
      const int top_offset =
          ((n * outputChannels + c) * outputHeight + h) * outputWidth + w;
      const int bottom_offset =
          ((n * inputChannels + c) * inputHeight + h_in) * inputWidth + w_in;
      buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
    } else {
      // The tap reads zero padding, so it contributes nothing.
      buffer_data[index] = 0;
    }
  }
}

template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{
public:
  // Runs the depthwise convolution forward pass on the GPU. It launches
  // ConvolutionDepthwiseForward on STREAM_DEFAULT with one thread per output
  // element. outputData's batchSize*outputChannels*outputHeight*outputWidth
  // elements are overwritten. The launch is asynchronous, and all pointers
  // are dereferenced on the device. Empty outputs are a no-op.
  void operator()(const T* inputData,
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* outputData) {
    const int outputSize =
        batchSize * outputChannels * outputHeight * outputWidth;
    // Nothing to compute. Returning here also avoids launching a grid with
    // a zero dimension, which CUDA rejects as an invalid configuration.
    if (outputSize <= 0) {
      return;
    }

    // Use 1024 threads per block. The block count is folded into a 2-D grid
    // with at most 512 blocks along x; the kernel linearizes it as
    // blockIdx.x * gridDim.y + blockIdx.y. When fewer than 512 blocks are
    // needed, x is reduced to that count so no idle blocks are launched.
    const size_t threadsPerBlock = 1024;
    const size_t maxBlockX = 512;
    size_t blocks = (outputSize + threadsPerBlock - 1) / threadsPerBlock;
    size_t blockX = blocks < maxBlockX ? blocks : maxBlockX;
    size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    // NOTE(review): launch errors are not checked here; consider the
    // project's usual CUDA error-check macro.
    ConvolutionDepthwiseForward<T>
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            outputSize,
            inputData,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            outputData);
  }
};

template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{
public:
  // Computes the depthwise convolution gradient w.r.t. the input on the GPU.
  // It launches ConvolutionDepthwiseInputBackward on STREAM_DEFAULT with one
  // thread per input element. The kernel ADDS (+=) into inputGrad, so
  // inputGrad must already hold valid values (e.g. zeros). The launch is
  // asynchronous, and all pointers are dereferenced on the device. Empty
  // inputs are a no-op.
  void operator()(const T* outputGrad,
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* inputGrad) {
    const int inputSize =
        batchSize * inputChannels * inputHeight * inputWidth;
    // Nothing to compute. Returning here also avoids launching a grid with
    // a zero dimension, which CUDA rejects as an invalid configuration.
    if (inputSize <= 0) {
      return;
    }

    // Use 1024 threads per block. The block count is folded into a 2-D grid
    // with at most 512 blocks along x; the kernel linearizes it as
    // blockIdx.x * gridDim.y + blockIdx.y. When fewer than 512 blocks are
    // needed, x is reduced to that count so no idle blocks are launched.
    const size_t threadsPerBlock = 1024;
    const size_t maxBlockX = 512;
    size_t blocks = (inputSize + threadsPerBlock - 1) / threadsPerBlock;
    size_t blockX = blocks < maxBlockX ? blocks : maxBlockX;
    size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    // NOTE(review): launch errors are not checked here; consider the
    // project's usual CUDA error-check macro.
    ConvolutionDepthwiseInputBackward<T>
        // NOLINT_NEXT_LINE(whitespace/operators)
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            inputSize,
            outputGrad,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            inputGrad);
  }
};

template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  // Computes the depthwise convolution gradient w.r.t. the filter on the GPU.
  //
  // For each batch sample it performs two steps:
  //   1. ConvolutionDepthwiseFilterBackward fills colData with the
  //      per-position products, viewed as an M x K matrix where
  //      M = inputChannels * filterHeight * filterWidth and
  //      K = outputHeight * outputWidth.
  //   2. filterGradMatrix.sumRows(colMatrix, 1, 1) folds the row sums into
  //      filterGrad. With scaleDest = 1, this presumably accumulates into
  //      the existing filterGrad values (confirm against BaseMatrix).
  // colData is device scratch space. It must hold at least
  // inputChannels * filterHeight * filterWidth * outputHeight * outputWidth
  // elements and is reused for every sample. All work is queued on the GPU;
  // empty problems are a no-op.
  void operator()(const T* outputGrad,
                  const T* inputData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* colData,
                  T* filterGrad) {
    const int colDataSize =
        inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
    // Nothing to accumulate. Returning here also avoids a zero-sized
    // (invalid) grid, and a host-side division by zero below when
    // outputHeight * outputWidth == 0.
    if (colDataSize <= 0) {
      return;
    }

    // Use 1024 threads per block. The block count is folded into a 2-D grid
    // with at most 512 blocks along x; the kernel linearizes it as
    // blockIdx.x * gridDim.y + blockIdx.y. When fewer than 512 blocks are
    // needed, x is reduced to that count so no idle blocks are launched.
    const size_t threadsPerBlock = 1024;
    const size_t maxBlockX = 512;
    size_t blocks = (colDataSize + threadsPerBlock - 1) / threadsPerBlock;
    size_t blockX = blocks < maxBlockX ? blocks : maxBlockX;
    size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    // Loop-invariant matrix views. filterGrad is a column of M entries,
    // and colData holds M rows (one per (c, kh, kw)) of K output positions.
    const int M = colDataSize / outputHeight / outputWidth;
    const int K = outputHeight * outputWidth;
    BaseMatrix filterGradMatrix(inputChannels * filterHeight * filterWidth,
                                1, filterGrad, false, true);
    BaseMatrix colMatrix(M, K, colData, false, true);

    for (int i = 0; i < batchSize; ++i) {
      // NOTE(review): launch errors are not checked here; consider the
      // project's usual CUDA error-check macro.
      ConvolutionDepthwiseFilterBackward<T>
          <<< grid, threads, 0, STREAM_DEFAULT >>>(
              i,
              colDataSize,
              outputGrad,
              inputData,
              batchSize,
              outputChannels,
              outputHeight,
              outputWidth,
              inputChannels,
              inputHeight,
              inputWidth,
              filterHeight,
              filterWidth,
              strideH,
              strideW,
              paddingH,
              paddingW,
              colData);

      // Fold this sample's row sums into filterGrad.
      filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0);
    }
  }
};

// Explicit instantiations of the GPU functors for the floating-point type
// selected at build time: double when PADDLE_TYPE_DOUBLE is defined, float
// otherwise.
#ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif

}  // namespace paddle