DepthwiseConvOpGpu.cu 12.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "DepthwiseConvOp.h"
16
#include "GemmFunctor.h"
17
#include "paddle/math/BaseMatrix.h"
18 19

namespace paddle {
20

21
// CUDA kernel to compute the depthwise convolution forward pass.
//
// Each thread produces exactly one element of outputData. The index
// decomposition below treats outputData as
// [batchSize][outputChannels][outputHeight][outputWidth] and inputData as
// [batch][inputChannels][inputHeight][inputWidth]. Output channel oc reads
// input channel oc / filterMultiplier and its own filterHeight x filterWidth
// slice of filterData (starting at oc * filterHeight * filterWidth).
//
// Launch layout: the flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, i.e. a
// 2-D grid of 1-D blocks. Threads whose id is >= nthreads exit immediately.
template <class T>
__global__ void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const outputData) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Coordinates of the output element owned by this thread.
  const int ow = tid % outputWidth;
  const int oh = (tid / outputWidth) % outputHeight;
  const int oc = (tid / outputHeight / outputWidth) % outputChannels;
  const int n = tid / outputChannels / outputHeight / outputWidth;
  const int ic = oc / filterMultiplier;

  const T* const filt = filterData + oc * filterHeight * filterWidth;
  const T* const plane =
      inputData + (n * inputChannels + ic) * inputHeight * inputWidth;

  // Top-left input coordinate of this output's receptive field (may be
  // negative when it overlaps the padding).
  const int hBase = oh * strideH - paddingH;
  const int wBase = ow * strideW - paddingW;

  // True when the whole receptive field lies inside the input, so the
  // per-tap bounds checks can be skipped.
  const bool interior = hBase >= 0 && wBase >= 0 &&
                        hBase + filterHeight <= inputHeight &&
                        wBase + filterWidth <= inputWidth;

  T acc = 0;
  if (interior) {
    for (int i = 0; i < filterHeight; ++i) {
      const T* const row = plane + (hBase + i) * inputWidth + wBase;
      for (int j = 0; j < filterWidth; ++j) {
        acc += filt[i * filterWidth + j] * row[j];
      }
    }
  } else {
    for (int i = 0; i < filterHeight; ++i) {
      const int h = hBase + i;
      if (h < 0 || h >= inputHeight) {
        continue;  // whole filter row falls in the padding
      }
      for (int j = 0; j < filterWidth; ++j) {
        const int x = wBase + j;
        if (x >= 0 && x < inputWidth) {
          acc += filt[i * filterWidth + j] * plane[h * inputWidth + x];
        }
      }
    }
  }
  outputData[tid] = acc;
}

79
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
//
// Each thread owns one element of bottom_diff. The index decomposition treats
// it as [batch][inputChannels][inputHeight][inputWidth], and top_diff as
// [batch][outputChannels][outputHeight][outputWidth].
//
// Input channel ic feeds the filterMultiplier output channels
// ic * filterMultiplier .. ic * filterMultiplier + filterMultiplier - 1. For
// each of them, every filter tap (kh, kw) is visited. A tap contributes
// weight * top_diff when it maps this input pixel onto an output pixel that
// lies on the stride grid and inside the output plane.
//
// The sum is ADDED to bottom_diff[index] (+=), so the caller must have
// initialised bottom_diff.
//
// Launch layout: the flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; threads
// with id >= nthreads exit. `num` is not read by this kernel.
template <class T>
__global__ void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const bottom_diff) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Coordinates of the input element owned by this thread.
  const int iw = tid % inputWidth;
  const int ih = (tid / inputWidth) % inputHeight;
  const int ic = (tid / inputHeight / inputWidth) % inputChannels;
  const int n = tid / inputChannels / inputHeight / inputWidth;

  const int filterArea = filterHeight * filterWidth;
  const int outPlaneSize = outputHeight * outputWidth;

  T acc = 0;
  for (int m = 0; m < filterMultiplier; ++m) {
    // The filter slice must be selected by the output channel, not the
    // input channel.
    const int oc = ic * filterMultiplier + m;
    const T* const filt = weight_data + oc * filterArea;
    const T* const gradPlane =
        top_diff + (n * outputChannels + oc) * outPlaneSize;

    for (int i = 0; i < filterHeight; ++i) {
      const int hShifted = ih + paddingH - i;
      for (int j = 0; j < filterWidth; ++j) {
        const int wShifted = iw + paddingW - j;
        // Only taps that land exactly on the stride grid correspond to an
        // output pixel.
        if (hShifted % strideH != 0 || wShifted % strideW != 0) {
          continue;
        }
        const int oh = hShifted / strideH;
        const int ow = wShifted / strideW;
        // TODO(zhaolong): this bounds check hurts performance and needs to be
        // optimized.
        if (oh < 0 || oh >= outputHeight || ow < 0 || ow >= outputWidth) {
          continue;
        }
        acc += filt[i * filterWidth + j] * gradPlane[oh * outputWidth + ow];
      }
    }
  }
  bottom_diff[tid] += acc;
}

124
// CUDA kernel to compute the depthwise convolution backprop w.r.t filter.
//
// Processes one sample (num_i) of the batch. Each thread owns one element of
// buffer_data. The index decomposition treats buffer_data as
// [outputChannels][filterHeight][filterWidth][outputHeight][outputWidth].
//
// For filter tap (kh, kw) of output channel oc at output position (oh, ow),
// the thread stores top_diff[num_i][oc][oh][ow] * inputData[num_i][ic][ih][iw],
// where ic = oc / filterMultiplier. It stores 0 when the corresponding input
// pixel (ih, iw) falls in the padding. Reducing buffer_data over its last two
// axes therefore yields this sample's per-tap filter gradient.
//
// Launch layout: the flat thread id is
// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; threads
// with id >= nthreads exit. `num` is not read by this kernel.
template <class T>
__global__ void ConvolutionDepthwiseFilterBackward(const int num_i,
    const int nthreads, const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputChannels, const int inputHeight,
    const int inputWidth, const int filterMultiplier, const int filterHeight,
    const int filterWidth, const int strideH, const int strideW,
    const int paddingH, const int paddingW, T* const buffer_data) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int ow = tid % outputWidth;
  const int oh = (tid / outputWidth) % outputHeight;
  const int kw = (tid / outputHeight / outputWidth) % filterWidth;
  const int kh =
      (tid / filterWidth / outputHeight / outputWidth) % filterHeight;

  // Input pixel touched by tap (kh, kw) when producing output (oh, ow).
  const int ih = oh * strideH - paddingH + kh;
  const int iw = ow * strideW - paddingW + kw;

  T product = 0;
  const bool inside =
      ih >= 0 && ih < inputHeight && iw >= 0 && iw < inputWidth;
  if (inside) {
    const int oc =
        tid / filterHeight / filterWidth / outputHeight / outputWidth;
    const int ic = oc / filterMultiplier;
    const int gradOffset =
        ((num_i * outputChannels + oc) * outputHeight + oh) * outputWidth + ow;
    const int inOffset =
        ((num_i * inputChannels + ic) * inputHeight + ih) * inputWidth + iw;
    product = top_diff[gradOffset] * inputData[inOffset];
  }
  buffer_data[tid] = product;
}

template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T> {
public:
  /**
   * Depthwise convolution forward pass on the GPU.
   *
   * Launches ConvolutionDepthwiseForward with one thread per output element
   * (batchSize * outputChannels * outputHeight * outputWidth threads) on
   * STREAM_DEFAULT, using 1024-thread blocks arranged in a 2-D grid whose
   * x-extent is at most 512.
   *
   * An empty output is a no-op. Launching it would request a zero-sized
   * grid, which CUDA rejects as an invalid launch configuration.
   *
   * NOTE(review): the launch result is not checked here; errors surface at a
   * later synchronizing call.
   */
  void operator()(const T* inputData,
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterMultiplier,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* outputData) {
    const int outputSize =
        batchSize * outputChannels * outputHeight * outputWidth;
    if (outputSize <= 0) {
      return;
    }

    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    const size_t blocks =
        (static_cast<size_t>(outputSize) + kThreadsPerBlock - 1) /
        kThreadsPerBlock;
    // Clamp the x-extent to what is needed so small problems do not launch
    // 512 mostly idle blocks. The kernel linearises blocks as
    // blockIdx.x * gridDim.y + blockIdx.y, so blockX * blockY >= blocks
    // guarantees full coverage.
    const size_t blockX = blocks < kMaxGridX ? blocks : kMaxGridX;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseForward<T>
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            outputSize,
            inputData,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterMultiplier,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            outputData);
  }
};

template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T> {
public:
  /**
   * Depthwise convolution gradient w.r.t. the input, on the GPU.
   *
   * Launches ConvolutionDepthwiseInputBackward with one thread per input
   * element (batchSize * inputChannels * inputHeight * inputWidth threads) on
   * STREAM_DEFAULT, using 1024-thread blocks in a 2-D grid whose x-extent is
   * at most 512. The kernel ADDS (+=) its result into inputGrad, so inputGrad
   * must already be initialised by the caller.
   *
   * An empty input is a no-op. Launching it would request a zero-sized grid,
   * which CUDA rejects as an invalid launch configuration.
   *
   * NOTE(review): the launch result is not checked here; errors surface at a
   * later synchronizing call.
   */
  void operator()(const T* outputGrad,
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterMultiplier,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* inputGrad) {
    const int inputSize = batchSize * inputChannels * inputHeight * inputWidth;
    if (inputSize <= 0) {
      return;
    }

    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    const size_t blocks =
        (static_cast<size_t>(inputSize) + kThreadsPerBlock - 1) /
        kThreadsPerBlock;
    // Clamp the x-extent to what is needed. The kernel linearises blocks as
    // blockIdx.x * gridDim.y + blockIdx.y, so blockX * blockY >= blocks
    // guarantees full coverage.
    const size_t blockX = blocks < kMaxGridX ? blocks : kMaxGridX;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseInputBackward<T>
        // NOLINT_NEXT_LINE(whitespace/operators)
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            inputSize,
            outputGrad,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputChannels,
            inputHeight,
            inputWidth,
            filterMultiplier,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            inputGrad);
  }
};

template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  /**
   * Depthwise convolution gradient w.r.t. the filter, on the GPU.
   *
   * For each sample i in [0, batchSize):
   *  - ConvolutionDepthwiseFilterBackward fills colData. colData is viewed as
   *    an M x K matrix with M = outputChannels * filterHeight * filterWidth
   *    and K = outputHeight * outputWidth.
   *  - BaseMatrix::sumRows(colMatrix, 1, 1) then folds the rows of colData
   *    into filterGrad, which is viewed as an M x 1 matrix.
   *
   * @param colData     device scratch buffer of at least M * K elements; it
   *                    is overwritten for every sample.
   * @param filterGrad  M-element filter gradient. NOTE(review): sumRows is
   *                    called with scaleDest = 1, which presumably
   *                    accumulates into the existing contents; confirm
   *                    against BaseMatrix::sumRows.
   *
   * An empty problem (M * K == 0, e.g. a zero-sized output plane) is a no-op.
   * Proceeding would launch a zero-sized grid, which CUDA rejects as an
   * invalid launch configuration.
   *
   * NOTE(review): kernel launch results are not checked here.
   */
  void operator()(const T* outputGrad,
                  const T* inputData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputChannels,
                  int inputHeight,
                  int inputWidth,
                  int filterMultiplier,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* colData,
                  T* filterGrad) {
    // Shape of the per-sample column buffer: one row per filter tap of every
    // output channel, one column per output position.
    const int K = outputHeight * outputWidth;
    const int M = outputChannels * filterHeight * filterWidth;
    const int colDataSize = M * K;
    if (colDataSize <= 0) {
      return;
    }

    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    const size_t blocks =
        (static_cast<size_t>(colDataSize) + kThreadsPerBlock - 1) /
        kThreadsPerBlock;
    // Clamp the x-extent to what is needed. The kernel linearises blocks as
    // blockIdx.x * gridDim.y + blockIdx.y, so blockX * blockY >= blocks
    // guarantees full coverage.
    const size_t blockX = blocks < kMaxGridX ? blocks : kMaxGridX;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    // Both views are loop-invariant: colData is reused for every sample.
    BaseMatrix filterGradMatrix(M, 1, filterGrad, false, true);
    BaseMatrix colMatrix(M, K, colData, false, true);

    for (int i = 0; i < batchSize; ++i) {
      ConvolutionDepthwiseFilterBackward<T>
          <<< grid, threads, 0, STREAM_DEFAULT >>>(
              i,
              colDataSize,
              outputGrad,
              inputData,
              batchSize,
              outputChannels,
              outputHeight,
              outputWidth,
              inputChannels,
              inputHeight,
              inputWidth,
              filterMultiplier,
              filterHeight,
              filterWidth,
              strideH,
              strideW,
              paddingH,
              paddingW,
              colData);
      filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0);
    }
  }
};

329
// Explicitly instantiate the GPU functors for the build's element type:
// double when PADDLE_TYPE_DOUBLE is defined, float otherwise.
#ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif
338 339

}  // namespace paddle