DepthwiseConvOpGpu.cu 10.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "DepthwiseConvOp.h"
16
#include "GemmFunctor.h"
17
#include "paddle/math/BaseMatrix.h"
18 19

namespace paddle {
20

21
// CUDA kernel computing the forward pass of a depthwise convolution.
//
// Each thread produces exactly one output element. `nthreads` is the number
// of output elements; threads whose flattened index is >= nthreads exit
// without touching memory. The flattened index is built from blockIdx.x,
// blockIdx.y and threadIdx.x, so the kernel expects a 2-D grid of 1-D blocks.
//
// The index is decoded as (n, c, h, w) with w varying fastest. The same
// channel index c selects the filter plane (filterHeight x filterWidth
// values) and the input plane; the input is addressed with outputChannels
// as its channel count.
//
// Taps that fall outside the input (the padded border) contribute nothing.
template <class T>
__global__
void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const outputData) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Peel (w, h, c, n) off the flattened index, fastest dimension first.
  const int w = index % outputWidth;
  int rest = index / outputWidth;
  const int h = rest % outputHeight;
  rest /= outputHeight;
  const int c = rest % outputChannels;
  const int n = rest / outputChannels;

  const T* const filter = filterData + c * filterHeight * filterWidth;

  // Row index of the first row of this (n, c) plane in the input.
  const int planeRow = (n * outputChannels + c) * inputHeight;

  // Top-left corner of the filter window in input coordinates.
  const int hFirst = -paddingH + h * strideH;
  const int wFirst = -paddingW + w * strideW;

  // True when the whole window lies inside the input, so the per-tap
  // bounds checks can be skipped.
  const bool windowInside =
      (hFirst >= 0) && (hFirst + filterHeight - 1 < inputHeight) &&
      (wFirst >= 0) && (wFirst + filterWidth - 1 < inputWidth);

  T acc = 0;
  if (windowInside) {
    for (int kh = 0; kh < filterHeight; ++kh) {
      for (int kw = 0; kw < filterWidth; ++kw) {
        const int offset =
            (planeRow + hFirst + kh) * inputWidth + wFirst + kw;
        acc += filter[kh * filterWidth + kw] * inputData[offset];
      }
    }
  } else {
    for (int kh = 0; kh < filterHeight; ++kh) {
      const int hIn = hFirst + kh;
      if ((hIn < 0) || (hIn >= inputHeight)) {
        continue;  // whole filter row is in the padding
      }
      for (int kw = 0; kw < filterWidth; ++kw) {
        const int wIn = wFirst + kw;
        if ((wIn < 0) || (wIn >= inputWidth)) {
          continue;  // tap is in the padding
        }
        const int offset = (planeRow + hIn) * inputWidth + wIn;
        acc += filter[kh * filterWidth + kw] * inputData[offset];
      }
    }
  }
  outputData[index] = acc;
}

77
// CUDA kernel computing the depthwise convolution gradient w.r.t. the input.
//
// Each thread handles one input-gradient element. `nthreads` is the number
// of such elements; threads whose flattened index is >= nthreads exit early.
// The flattened index is built from blockIdx.x, blockIdx.y and threadIdx.x,
// so the kernel expects a 2-D grid of 1-D blocks.
//
// The index is decoded as (n, c, h, w) over
// (outputChannels, inputHeight, inputWidth) with w varying fastest; note
// that outputChannels is used as the channel count on the input side too.
//
// The computed value is ADDED to bottom_diff[index] (not assigned), so the
// caller decides what the buffer holds beforehand. `num` is not read.
template <class T>
__global__
void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
     T* const bottom_diff) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Peel (w, h, c, n) off the flattened index, fastest dimension first.
  const int w = index % inputWidth;
  int rest = index / inputWidth;
  const int h = rest % inputHeight;
  rest /= inputHeight;
  const int c = rest % outputChannels;
  const int n = rest / outputChannels;

  const T* const filter = weight_data + c * filterHeight * filterWidth;

  // Row index of the first row of this (n, c) plane in top_diff.
  const int topPlaneRow = (n * outputChannels + c) * outputHeight;

  T grad = 0;
  for (int kh = 0; kh < filterHeight; ++kh) {
    const int hScaled = h + paddingH - kh;
    for (int kw = 0; kw < filterWidth; ++kw) {
      const int wScaled = w + paddingW - kw;
      // Only output positions that land exactly on the stride grid used
      // this input element with filter tap (kh, kw).
      if (((hScaled % strideH) != 0) || ((wScaled % strideW) != 0)) {
        continue;
      }
      const int hOut = hScaled / strideH;
      const int wOut = wScaled / strideW;
      // TODO(zhaolong): this bounds check hurts efficiency and should be
      // optimized.
      if ((hOut < 0) || (hOut >= outputHeight) ||
          (wOut < 0) || (wOut >= outputWidth)) {
        continue;
      }
      const int offset = (topPlaneRow + hOut) * outputWidth + wOut;
      grad += filter[kh * filterWidth + kw] * top_diff[offset];
    }
  }
  bottom_diff[index] += grad;
}

118
// CUDA kernel computing the per-position terms of the depthwise convolution
// gradient w.r.t. the filter, for a single batch sample `num_i`.
//
// Each thread writes one element of buffer_data. `nthreads` is the number of
// elements; threads whose flattened index is >= nthreads exit early. The
// flattened index is built from blockIdx.x, blockIdx.y and threadIdx.x, so
// the kernel expects a 2-D grid of 1-D blocks.
//
// The index is decoded as (c, kh, kw, h, w) with w varying fastest, where
// (kh, kw) is a filter tap and (h, w) an output position. The element is
// top_diff(num_i, c, h, w) * inputData(num_i, c, h_in, w_in) when the tap
// reads inside the input, and 0 when it reads the padding. Every element
// with index < nthreads is overwritten. `num` is not read.
template <class T>
__global__
void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
    const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const buffer_data) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Peel (w, h, kw, kh, c) off the flattened index, fastest dimension first.
  const int w = index % outputWidth;
  int rest = index / outputWidth;
  const int h = rest % outputHeight;
  rest /= outputHeight;
  const int kw = rest % filterWidth;
  rest /= filterWidth;
  const int kh = rest % filterHeight;
  const int c = rest / filterHeight;

  // Input coordinate read by tap (kh, kw) at output position (h, w).
  const int hIn = -paddingH + h * strideH + kh;
  const int wIn = -paddingW + w * strideW + kw;
  const bool insideInput =
      (hIn >= 0) && (hIn < inputHeight) && (wIn >= 0) && (wIn < inputWidth);

  T term = 0;
  if (insideInput) {
    const int plane = num_i * outputChannels + c;
    const int topOffset = (plane * outputHeight + h) * outputWidth + w;
    const int bottomOffset = (plane * inputHeight + hIn) * inputWidth + wIn;
    term = top_diff[topOffset] * inputData[bottomOffset];
  }
  buffer_data[index] = term;
}

// GPU specialization of the depthwise convolution forward functor.
//
// Launches ConvolutionDepthwiseForward on STREAM_DEFAULT with one thread per
// output element (batchSize * outputChannels * outputHeight * outputWidth).
// The launch is asynchronous; no synchronization or error check is done
// here.
//
// inputData / filterData / outputData are passed straight to the kernel and
// must therefore be device-accessible pointers.
template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{
public:
  void operator()(const T* inputData,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputHeight,
            int inputWidth,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* outputData){
    const int kThreadsPerBlock = 1024;
    const int kGridX = 512;

    int outputSize = batchSize * outputChannels * outputHeight * outputWidth;

    // Nothing to compute. Launching anyway would be an error: zero blocks
    // makes the grid's y extent 0, which is an invalid launch configuration.
    if (outputSize == 0) {
      return;
    }

    // Spread ceil(outputSize / 1024) blocks over a 512 x blockY grid; the
    // kernel discards the surplus threads via its nthreads check.
    size_t blocks = (outputSize + kThreadsPerBlock - 1) / kThreadsPerBlock;
    size_t blockX = kGridX;
    size_t blockY = (blocks + kGridX - 1) / kGridX;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseForward<T>
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            outputSize,
            inputData,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputHeight,
            inputWidth,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            outputData);
    }
};

// GPU specialization of the functor computing the depthwise convolution
// gradient w.r.t. the input.
//
// Launches ConvolutionDepthwiseInputBackward on STREAM_DEFAULT with one
// thread per input-gradient element
// (batchSize * inputChannels * inputHeight * inputWidth). The kernel ADDS
// its result into inputGrad, so existing contents of inputGrad are kept and
// accumulated onto. The launch is asynchronous; no synchronization or error
// check is done here.
//
// NOTE(review): the element count uses inputChannels while the kernel
// decodes indices with outputChannels; this is only consistent when the two
// are equal. Confirm against callers.
template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{
public:
  void operator()(const T* outputGrad,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputChannels,
            int inputHeight,
            int inputWidth,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* inputGrad){
    const int kThreadsPerBlock = 1024;
    const int kGridX = 512;

    int inputSize = batchSize * inputChannels * inputHeight * inputWidth;

    // Nothing to compute. Launching anyway would be an error: zero blocks
    // makes the grid's y extent 0, which is an invalid launch configuration.
    if (inputSize == 0) {
      return;
    }

    // Spread ceil(inputSize / 1024) blocks over a 512 x blockY grid; the
    // kernel discards the surplus threads via its nthreads check.
    size_t blocks = (inputSize + kThreadsPerBlock - 1) / kThreadsPerBlock;
    size_t blockX = kGridX;
    size_t blockY = (blocks + kGridX - 1) / kGridX;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseInputBackward<T>
          // NOLINT_NEXT_LINE(whitespace/operators)
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            inputSize,
            outputGrad,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputHeight,
            inputWidth,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            inputGrad);
    }
};

// GPU specialization of the functor computing the depthwise convolution
// gradient w.r.t. the filter.
//
// For every batch sample i it:
//   1. launches ConvolutionDepthwiseFilterBackward on STREAM_DEFAULT, which
//      overwrites colData with
//      inputChannels * filterHeight * filterWidth * outputHeight * outputWidth
//      per-position products for that sample;
//   2. views colData as an M x K matrix
//      (M = inputChannels * filterHeight * filterWidth,
//       K = outputHeight * outputWidth) and calls
//      filterGradMatrix.sumRows(colMatrix, 1, 1).
//      NOTE(review): presumably this adds each row's sum onto the existing
//      filterGrad entry; confirm against BaseMatrix::sumRows.
//
// colData is scratch space and must hold at least M * K elements. All
// pointers are passed to device code and must be device-accessible.
template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  void operator()(const T* outputGrad,
                const T* inputData,
                int batchSize,
                int outputChannels,
                int outputHeight,
                int outputWidth,
                int inputChannels,
                int inputHeight,
                int inputWidth,
                int filterHeight,
                int filterWidth,
                int strideH,
                int strideW,
                int paddingH,
                int paddingW,
                T* colData,
                T* filterGrad){
        const int kThreadsPerBlock = 1024;
        const int kGridX = 512;

        int colDataSize = inputChannels * filterHeight * filterWidth * outputHeight * outputWidth;

        // Nothing to accumulate. Continuing would launch with a grid whose
        // y extent is 0 (invalid configuration) and, when an output extent
        // is 0, divide by zero while deriving the matrix shape.
        if (colDataSize == 0) {
          return;
        }

        // Spread ceil(colDataSize / 1024) blocks over a 512 x blockY grid;
        // the kernel discards the surplus threads via its nthreads check.
        size_t blocks = (colDataSize + kThreadsPerBlock - 1) / kThreadsPerBlock;
        size_t blockX = kGridX;
        size_t blockY = (blocks + kGridX - 1) / kGridX;
        dim3 threads(kThreadsPerBlock, 1);
        dim3 grid(blockX, blockY);
        BaseMatrix filterGradMatrix(inputChannels * filterHeight * filterWidth, 1, filterGrad, false, true);

        // Shape of the colData view; identical for every sample, so it is
        // computed once. Both divisors are non-zero because colDataSize != 0.
        const int M = colDataSize / outputHeight / outputWidth;
        const int K = outputHeight * outputWidth;

        for (int i = 0; i < batchSize; i++) {
            ConvolutionDepthwiseFilterBackward<T>
                <<< grid, threads, 0, STREAM_DEFAULT >>>(
                    i,
                    colDataSize,
                    outputGrad,
                    inputData,
                    batchSize,
                    outputChannels,
                    outputHeight,
                    outputWidth,
                    inputHeight,
                    inputWidth,
                    filterHeight,
                    filterWidth,
                    strideH,
                    strideW,
                    paddingH,
                    paddingW,
                    colData
                );

            BaseMatrix colMatrix(M, K, colData, false, true);
            filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0);
        }
    }
};

312
// Explicit instantiations of the GPU functors: double when
// PADDLE_TYPE_DOUBLE is defined, float otherwise.
#ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif
321 322

}  // namespace paddle