// DepthwiseConvOpGpu.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "DepthwiseConvOp.h"
#include "GemmFunctor.h"

namespace paddle {
// CUDA kernel to compute the depthwise convolution forward pass.
//
// One thread per output element. The flat thread index is
//   (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// and is decomposed NCHW-style over the output tensor:
//   index = ((n * outputChannels + c) * outputHeight + h) * outputWidth + w.
// Output channel c is convolved with its own filterHeight x filterWidth filter
// (filterData + c * filterHeight * filterWidth) and reads only input channel c,
// so the input is indexed with outputChannels as its channel count.
// NOTE(review): this assumes inputChannels == outputChannels (channel
// multiplier 1) -- confirm callers guarantee it.
// Threads with index >= nthreads do nothing, so the grid may be oversized.
// batchSize is not read; it is kept for interface symmetry.
template <class T>
__global__
void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const outputData) {
  int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  if (index < nthreads) {
    const int n = index / outputChannels / outputHeight / outputWidth;
    const int c = (index / outputHeight / outputWidth) % outputChannels;
    const int h = (index / outputWidth) % outputHeight;
    const int w = index % outputWidth;
    const T* weight = filterData + c * filterHeight * filterWidth;

    // Receptive field of output (h, w) in input coordinates. The start may be
    // negative, and the end may run past the input, when the window overlaps
    // the zero padding.
    const int h_in_start = -paddingH + h * strideH;
    const int w_in_start = -paddingW + w * strideW;
    const int h_in_end = h_in_start + filterHeight - 1;
    const int w_in_end = w_in_start + filterWidth - 1;

    // Base of the (n, c) input plane; only dereferenced at in-bounds offsets.
    const T* const inputPlane =
        inputData + (n * outputChannels + c) * inputHeight * inputWidth;

    T value = 0;
    if ((h_in_start >= 0) && (h_in_end < inputHeight) &&
        (w_in_start >= 0) && (w_in_end < inputWidth)) {
      // Fast path: the whole window lies inside the input, so no per-tap
      // bounds checks are needed.
      for (int kh = 0; kh < filterHeight; ++kh) {
        for (int kw = 0; kw < filterWidth; ++kw) {
          const int h_in = h_in_start + kh;
          const int w_in = w_in_start + kw;
          value += (*weight) * inputPlane[h_in * inputWidth + w_in];
          ++weight;
        }
      }
    } else {
      // Border path: taps that land in the zero padding contribute nothing,
      // but the weight pointer still advances so taps stay aligned.
      for (int kh = 0; kh < filterHeight; ++kh) {
        for (int kw = 0; kw < filterWidth; ++kw) {
          const int h_in = h_in_start + kh;
          const int w_in = w_in_start + kw;
          if ((h_in >= 0) && (h_in < inputHeight) &&
              (w_in >= 0) && (w_in < inputWidth)) {
            value += (*weight) * inputPlane[h_in * inputWidth + w_in];
          }
          ++weight;
        }
      }
    }
    outputData[index] = value;
  }
}

// CUDA kernel to compute the depthwise convolution backprop w.r.t. input.
//
// One thread per input-gradient element. The flat thread index is
//   (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// and is decomposed NCHW-style over the input, using outputChannels as the
// channel count:
//   index = ((n * outputChannels + c) * inputHeight + h) * inputWidth + w.
// NOTE(review): this assumes inputChannels == outputChannels -- confirm.
// For every filter tap (kh, kw) the thread finds the output position that
// read input (h, w) through that tap, if one exists, and accumulates
// weight[c][kh][kw] * top_diff[n][c][h_out][w_out].
// The sum is ADDED to bottom_diff[index] (not assigned), so the caller must
// initialize bottom_diff. num is not read; it is kept for interface symmetry.
template <class T>
__global__
void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const bottom_diff) {
  int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    const int n = index / outputChannels / inputHeight / inputWidth;
    const int c = (index / inputHeight / inputWidth) % outputChannels;
    const int h = (index / inputWidth) % inputHeight;
    const int w = index % inputWidth;
    const T* weight = weight_data + c * filterHeight * filterWidth;

    // Base of the (n, c) output-gradient plane; only dereferenced in bounds.
    const T* const topPlane =
        top_diff + (n * outputChannels + c) * outputHeight * outputWidth;

    T value = 0;
    for (int kh = 0; kh < filterHeight; ++kh) {
      for (int kw = 0; kw < filterWidth; ++kw) {
        // The forward pass read h_in = h_out * strideH - paddingH + kh, so an
        // output row exists only if h + paddingH - kh is a multiple of strideH
        // (same for columns).
        const int h_out_s = h + paddingH - kh;
        const int w_out_s = w + paddingW - kw;
        if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) {
          // A negative multiple (e.g. -strideH) passes the modulo test and
          // yields a negative h_out/w_out, which the range check below rejects.
          const int h_out = h_out_s / strideH;
          const int w_out = w_out_s / strideW;
          // TODO(zhaolong): this branch hurts performance; it could be removed
          // by computing the valid (kh, kw) range up front.
          if ((h_out >= 0) && (h_out < outputHeight) &&
              (w_out >= 0) && (w_out < outputWidth)) {
            value += (*weight) * topPlane[h_out * outputWidth + w_out];
          }
        }
        ++weight;
      }
    }
    bottom_diff[index] += value;
  }
}

// CUDA kernel computing the per-sample partial products for the depthwise
// filter gradient.
//
// Processes a single batch sample n = num_i. There is one thread per element
// of buffer_data, whose flat index is laid out as
//   index = (((c * filterHeight + kh) * filterWidth + kw) * outputHeight + h)
//           * outputWidth + w,
// i.e. a row-major (channels * filterHeight * filterWidth) x
// (outputHeight * outputWidth) matrix. Each element is
//   top_diff[n][c][h][w] * input[n][c][h_in][w_in],
// where (h_in, w_in) is the input pixel filter tap (kh, kw) saw at output
// position (h, w). The element is set to 0 when that pixel is in the padding.
// Summing a row over (h, w) therefore gives this sample's gradient for filter
// tap (c, kh, kw). Both tensors are indexed with outputChannels as the channel
// count. num is not read; it is kept for interface symmetry.
template <class T>
__global__
void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
    const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const buffer_data) {
  int index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    const int h = (index / outputWidth) % outputHeight;
    const int w = index % outputWidth;
    const int kh = (index / filterWidth / outputHeight / outputWidth)
          % filterHeight;
    const int kw = (index / outputHeight / outputWidth) % filterWidth;
    // Input pixel seen by tap (kh, kw) at output (h, w); may be in the padding.
    const int h_in = -paddingH + h * strideH + kh;
    const int w_in = -paddingW + w * strideW + kw;
    if ((h_in >= 0) && (h_in < inputHeight) &&
        (w_in >= 0) && (w_in < inputWidth)) {
      const int c =
          index / filterHeight / filterWidth / outputHeight / outputWidth;
      const int n = num_i;
      const int top_offset = ((n * outputChannels + c) * outputHeight + h)
            * outputWidth + w;
      const int bottom_offset = ((n * outputChannels + c) * inputHeight + h_in)
            * inputWidth + w_in;
      buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
    } else {
      // Padding contributes zero to the gradient.
      buffer_data[index] = 0;
    }
  }
}

template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T>{
public:
  // Computes the depthwise convolution forward pass on the GPU by launching
  // ConvolutionDepthwiseForward on STREAM_DEFAULT. There is one thread per
  // output element (batchSize * outputChannels * outputHeight * outputWidth).
  // The launch is asynchronous, and this function neither checks for launch
  // errors nor synchronizes. An empty output is a no-op.
  void operator()(const T* inputData,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputHeight,
            int inputWidth,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* outputData){
    const int outputSize =
        batchSize * outputChannels * outputHeight * outputWidth;
    // Nothing to compute. Returning here also avoids launching a grid with a
    // zero dimension, which CUDA rejects as an invalid configuration.
    if (outputSize <= 0) {
      return;
    }

    // 1024 threads per block, with blocks arranged in a 2-D grid of at most
    // 512 columns. The kernel flattens (blockIdx.x, blockIdx.y) back into a
    // linear block id, so any grid with blockX * blockY >= blocks is valid.
    // Small tensors get only as many blocks as they need.
    const size_t threadsPerBlock = 1024;
    const size_t blocks = (outputSize + threadsPerBlock - 1) / threadsPerBlock;
    const size_t blockX = blocks < 512 ? blocks : 512;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseForward<T>
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            outputSize,
            inputData,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputHeight,
            inputWidth,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            outputData);
  }
};

template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T>{
public:
  // Computes the depthwise convolution gradient w.r.t. the input by launching
  // ConvolutionDepthwiseInputBackward on STREAM_DEFAULT. There is one thread
  // per input element (batchSize * inputChannels * inputHeight * inputWidth).
  // The kernel ADDS into inputGrad, so the caller must initialize it.
  // NOTE(review): the kernel decomposes the index using outputChannels, so
  // inputChannels is presumably equal to outputChannels -- confirm.
  // The launch is asynchronous and not error-checked. An empty input is a
  // no-op.
  void operator()(const T* outputGrad,
            const T* filterData,
            int batchSize,
            int outputChannels,
            int outputHeight,
            int outputWidth,
            int inputChannels,
            int inputHeight,
            int inputWidth,
            int filterHeight,
            int filterWidth,
            int strideH,
            int strideW,
            int paddingH,
            int paddingW,
            T* inputGrad){
    const int inputSize = batchSize * inputChannels * inputHeight * inputWidth;
    // Nothing to compute. Returning here also avoids an invalid launch with a
    // zero grid dimension.
    if (inputSize <= 0) {
      return;
    }

    // 1024 threads per block, with at most 512 grid columns. The kernel
    // flattens the 2-D grid into a linear block id, so any grid with
    // blockX * blockY >= blocks is valid.
    const size_t threadsPerBlock = 1024;
    const size_t blocks = (inputSize + threadsPerBlock - 1) / threadsPerBlock;
    const size_t blockX = blocks < 512 ? blocks : 512;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    ConvolutionDepthwiseInputBackward<T>
        // NOLINT_NEXT_LINE(whitespace/operators)
        <<< grid, threads, 0, STREAM_DEFAULT >>>(
            inputSize,
            outputGrad,
            filterData,
            batchSize,
            outputChannels,
            outputHeight,
            outputWidth,
            inputHeight,
            inputWidth,
            filterHeight,
            filterWidth,
            strideH,
            strideW,
            paddingH,
            paddingW,
            inputGrad);
  }
};

template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  // Accumulates the depthwise convolution gradient w.r.t. the filter.
  //
  // For each sample i in the batch:
  //   1. ConvolutionDepthwiseFilterBackward fills colData, viewed as an M x K
  //      matrix with M = inputChannels * filterHeight * filterWidth and
  //      K = outputHeight * outputWidth (lda = K).
  //   2. gemm computes filterGrad (M x 1) = 1 * colData (M x K)
  //      * multiplierData (K x 1) + 1 * filterGrad, so gradients accumulate
  //      across samples and into the caller's existing filterGrad.
  // NOTE(review): multiplierData presumably holds K ones, so step 2 is a row
  // sum. This assumes GemmFunctor follows BLAS gemm semantics with row-major
  // layout -- confirm both.
  // colData must hold at least M * K elements. Launches use STREAM_DEFAULT
  // and are not error-checked. Empty shapes are a no-op.
  void operator()(const T* outputGrad,
                const T* inputData,
                int batchSize,
                int outputChannels,
                int outputHeight,
                int outputWidth,
                int inputChannels,
                int inputHeight,
                int inputWidth,
                int filterHeight,
                int filterWidth,
                int strideH,
                int strideW,
                int paddingH,
                int paddingW,
                T* colData,
                T* multiplierData,
                T* filterGrad){
    const int colDataSize = inputChannels * filterHeight * filterWidth
        * outputHeight * outputWidth;
    // Nothing to accumulate. Returning here also avoids an invalid zero-sized
    // grid launch and a division by zero when outputHeight/outputWidth is 0.
    if (colDataSize <= 0) {
      return;
    }

    // 1024 threads per block, with at most 512 grid columns. The kernel
    // flattens the 2-D grid into a linear block id.
    const size_t threadsPerBlock = 1024;
    const size_t blocks = (colDataSize + threadsPerBlock - 1) / threadsPerBlock;
    const size_t blockX = blocks < 512 ? blocks : 512;
    const size_t blockY = (blocks + blockX - 1) / blockX;
    dim3 threads(threadsPerBlock, 1);
    dim3 grid(blockX, blockY);

    // The GEMM shape is the same for every sample, so set it up once. Use the
    // template's element type T rather than the global `real`.
    GemmFunctor<DEVICE_TYPE_GPU, T> gemm;
    const int M = inputChannels * filterHeight * filterWidth;
    const int N = 1;
    const int K = outputHeight * outputWidth;

    for (int i = 0; i < batchSize; i++) {
      ConvolutionDepthwiseFilterBackward<T>
          <<< grid, threads, 0, STREAM_DEFAULT >>>(
              i,
              colDataSize,
              outputGrad,
              inputData,
              batchSize,
              outputChannels,
              outputHeight,
              outputWidth,
              inputHeight,
              inputWidth,
              filterHeight,
              filterWidth,
              strideH,
              strideW,
              paddingH,
              paddingW,
              colData);
      // TODO: with N == 1 this is a matrix-vector product; a gemv would do.
      gemm(CblasNoTrans,
           CblasNoTrans,
           M,
           N,
           K,
           (T)1.0,
           colData,
           K,
           multiplierData,
           N,
           (T)1.0,
           filterGrad,
           N);
    }
  }
};

// Explicit instantiations of the GPU functors. The element type is double
// when PADDLE_TYPE_DOUBLE is defined and float otherwise.
#ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
#else
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
#endif

}  // namespace paddle