// DepthwiseConvOpGpu.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ConvOp.h"
#include "DepthwiseConvOp.h"
17 18
#include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
19 20 21

namespace paddle {
/**
 * Depthwise convolution forward pass, one thread per output element.
 *
 * The flat thread id is decoded into (batch, channel, row, col) of an NCHW
 * output. The thread then sums the filterHeight x filterWidth window of the
 * same channel of the input, weighted by that channel's filter slice.
 * The input is addressed with outputChannels as its channel count, so this
 * kernel assumes the input and output have the same number of channels.
 *
 * Launch layout: 1-D blocks on a 2-D grid. The flat id is
 * (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x, and
 * threads with an id >= nthreads return without writing.
 *
 * @param nthreads    number of output elements to produce.
 * @param inputData   input, addressed as [n][c][inputHeight][inputWidth].
 * @param filterData  filters, addressed as [c][filterHeight][filterWidth].
 * @param batchSize   not read by this kernel.
 * @param outputData  output, addressed as [n][c][outputHeight][outputWidth];
 *                    each covered element is overwritten.
 */
template <class T>
__global__ void ConvolutionDepthwiseForward(const int nthreads,
    const T* const inputData, const T* const filterData,
    const int batchSize, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const outputData) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Decode the flat id into NCHW coordinates of the output element.
  const int ow = tid % outputWidth;
  const int oh = (tid / outputWidth) % outputHeight;
  const int ch = (tid / outputHeight / outputWidth) % outputChannels;
  const int batch = tid / outputChannels / outputHeight / outputWidth;

  // Top-left corner of the receptive field in input coordinates (may be
  // negative when it overlaps the padding).
  const int hBegin = oh * strideH - paddingH;
  const int wBegin = ow * strideW - paddingW;

  const T* taps = filterData + ch * filterHeight * filterWidth;
  const T* plane =
      inputData + (batch * outputChannels + ch) * inputHeight * inputWidth;

  T acc = 0;
  const bool fullyInside = hBegin >= 0 &&
                           hBegin + filterHeight - 1 < inputHeight &&
                           wBegin >= 0 &&
                           wBegin + filterWidth - 1 < inputWidth;
  if (fullyInside) {
    // Every tap is in range, so skip the per-tap bounds checks.
    for (int kh = 0; kh < filterHeight; ++kh) {
      const T* row = plane + (hBegin + kh) * inputWidth + wBegin;
      for (int kw = 0; kw < filterWidth; ++kw) {
        acc += taps[kh * filterWidth + kw] * row[kw];
      }
    }
  } else {
    // Border window: taps that fall into the padding contribute nothing.
    for (int kh = 0; kh < filterHeight; ++kh) {
      const int hIn = hBegin + kh;
      for (int kw = 0; kw < filterWidth; ++kw) {
        const int wIn = wBegin + kw;
        if (hIn >= 0 && hIn < inputHeight && wIn >= 0 && wIn < inputWidth) {
          acc += taps[kh * filterWidth + kw] * plane[hIn * inputWidth + wIn];
        }
      }
    }
  }
  outputData[tid] = acc;
}

/**
 * Depthwise convolution input-gradient pass, one thread per input element.
 *
 * The flat thread id is decoded into (batch, channel, row, col) of an NCHW
 * input. For every filter tap, the thread finds the output position (if any)
 * whose window placed that tap on this input element. It accumulates
 * weight * outputGradient over those positions.
 * Both tensors are addressed with outputChannels as their channel count.
 *
 * Launch layout: same 2-D-grid flattening as the forward kernel; threads with
 * an id >= nthreads return without writing.
 *
 * @param nthreads     number of input elements to process.
 * @param top_diff     output gradient, [n][c][outputHeight][outputWidth].
 * @param weight_data  filters, [c][filterHeight][filterWidth].
 * @param num          not read by this kernel.
 * @param bottom_diff  input gradient; the computed sum is ADDED to the
 *                     existing value rather than overwriting it.
 */
template <class T>
__global__ void ConvolutionDepthwiseInputBackward(const int nthreads,
    const T* const top_diff, const T* const weight_data,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const bottom_diff) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int iw = tid % inputWidth;
  const int ih = (tid / inputWidth) % inputHeight;
  const int ch = (tid / inputHeight / inputWidth) % outputChannels;
  const int batch = tid / outputChannels / inputHeight / inputWidth;

  const T* taps = weight_data + ch * filterHeight * filterWidth;
  const T* gradPlane =
      top_diff + (batch * outputChannels + ch) * outputHeight * outputWidth;

  T acc = 0;
  for (int kh = 0; kh < filterHeight; ++kh) {
    const int hNum = ih + paddingH - kh;
    for (int kw = 0; kw < filterWidth; ++kw) {
      const int wNum = iw + paddingW - kw;
      // Only positions that land exactly on the stride grid came from an
      // output element.
      const bool onStride = (hNum % strideH == 0) && (wNum % strideW == 0);
      if (!onStride) {
        continue;
      }
      const int oh = hNum / strideH;
      const int ow = wNum / strideW;
      // TODO(zhaolong): this bounds check hurts performance and should be
      // optimized.
      if (oh < 0 || oh >= outputHeight || ow < 0 || ow >= outputWidth) {
        continue;
      }
      acc += taps[kh * filterWidth + kw] * gradPlane[oh * outputWidth + ow];
    }
  }
  bottom_diff[tid] += acc;
}

/**
 * Depthwise convolution filter-gradient pass for one batch sample (num_i).
 *
 * Each thread owns one element of buffer_data, laid out as
 * [c][filterHeight][filterWidth][outputHeight][outputWidth]. It writes the
 * product outputGradient(c, oh, ow) * input(c, hIn, wIn) for the tap
 * (kh, kw), or 0 when that tap falls into the padding. Summing over the
 * trailing (outputHeight, outputWidth) axes afterwards yields that sample's
 * filter gradient.
 *
 * Launch layout: same 2-D-grid flattening as the forward kernel; threads with
 * an id >= nthreads return without writing.
 *
 * @param num_i        index of the batch sample to process.
 * @param nthreads     number of buffer_data elements to fill.
 * @param top_diff     output gradient, [n][c][outputHeight][outputWidth].
 * @param inputData    input, [n][c][inputHeight][inputWidth].
 * @param num          not read by this kernel.
 * @param buffer_data  destination; every covered element is overwritten.
 */
template <class T>
__global__ void ConvolutionDepthwiseFilterBackward(const int num_i,
    const int nthreads,
    const T* const top_diff, const T* const inputData,
    const int num, const int outputChannels, const int outputHeight,
    const int outputWidth, const int inputHeight, const int inputWidth,
    const int filterHeight, const int filterWidth, const int strideH,
    const int strideW, const int paddingH, const int paddingW,
    T* const buffer_data) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int ow = tid % outputWidth;
  const int oh = (tid / outputWidth) % outputHeight;
  const int kw = (tid / outputHeight / outputWidth) % filterWidth;
  const int kh =
      (tid / filterWidth / outputHeight / outputWidth) % filterHeight;

  const int hIn = oh * strideH - paddingH + kh;
  const int wIn = ow * strideW - paddingW + kw;
  const bool inside =
      hIn >= 0 && hIn < inputHeight && wIn >= 0 && wIn < inputWidth;

  T product = 0;
  if (inside) {
    const int ch =
        tid / filterHeight / filterWidth / outputHeight / outputWidth;
    const int planeIdx = num_i * outputChannels + ch;
    product = top_diff[(planeIdx * outputHeight + oh) * outputWidth + ow] *
              inputData[(planeIdx * inputHeight + hIn) * inputWidth + wIn];
  }
  buffer_data[tid] = product;
}

/**
 * GPU specialization of the depthwise convolution forward functor.
 *
 * Launches ConvolutionDepthwiseForward with one thread per output element on
 * STREAM_DEFAULT. No synchronization is performed here.
 */
template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T> {
public:
  /**
   * @param outputSize  number of output elements to compute (presumably
   *                    batchSize * outputChannels * outputHeight *
   *                    outputWidth — confirm with caller). Nothing is
   *                    launched when it is <= 0.
   * @param inputData   device input, [n][c][inputHeight][inputWidth].
   * @param filterData  device filters, [c][filterHeight][filterWidth].
   * @param outputData  device output, overwritten by the kernel.
   * Remaining parameters describe the geometry and are forwarded unchanged.
   */
  void operator()(int outputSize,
                  const T* inputData,
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* outputData) {
    // Nothing to compute. Returning here also avoids a launch with a zero
    // grid dimension, which CUDA rejects as an invalid configuration.
    if (outputSize <= 0) {
      return;
    }

    // 1024 threads per block. The blocks are spread over a 2-D grid with at
    // most 512 columns, and the kernel linearizes them as
    // blockIdx.x * gridDim.y + blockIdx.y. Sizing the x dimension by the real
    // block count avoids launching hundreds of idle blocks for small outputs.
    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    const size_t blocks =
        (static_cast<size_t>(outputSize) + kThreadsPerBlock - 1) /
        kThreadsPerBlock;
    const size_t blockX = blocks < kMaxGridX ? blocks : kMaxGridX;
    const size_t blockY = (blocks + kMaxGridX - 1) / kMaxGridX;
    dim3 threads(static_cast<unsigned int>(kThreadsPerBlock), 1);
    dim3 grid(static_cast<unsigned int>(blockX),
              static_cast<unsigned int>(blockY));

    // NOTE(review): launch errors are not checked here.
    ConvolutionDepthwiseForward<T>
        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<grid, threads, 0, STREAM_DEFAULT>>>(outputSize,
                                               inputData,
                                               filterData,
                                               batchSize,
                                               outputChannels,
                                               outputHeight,
                                               outputWidth,
                                               inputHeight,
                                               inputWidth,
                                               filterHeight,
                                               filterWidth,
                                               strideH,
                                               strideW,
                                               paddingH,
                                               paddingW,
                                               outputData);
  }
};

/**
 * GPU specialization of the depthwise convolution input-gradient functor.
 *
 * Launches ConvolutionDepthwiseInputBackward with one thread per input
 * element on STREAM_DEFAULT. The kernel ADDS into inputGrad rather than
 * overwriting it. No synchronization is performed here.
 */
template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T> {
public:
  /**
   * @param inputSize   number of input elements (presumably batchSize *
   *                    outputChannels * inputHeight * inputWidth — confirm
   *                    with caller). Nothing is launched when it is <= 0.
   * @param outputGrad  device output gradient,
   *                    [n][c][outputHeight][outputWidth].
   * @param filterData  device filters, [c][filterHeight][filterWidth].
   * @param inputGrad   device input gradient, accumulated into.
   * Remaining parameters describe the geometry and are forwarded unchanged.
   */
  void operator()(int inputSize,
                  const T* outputGrad,
                  const T* filterData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* inputGrad) {
    // Nothing to compute. Returning here also avoids a launch with a zero
    // grid dimension (invalid configuration).
    if (inputSize <= 0) {
      return;
    }

    // Same layout as the forward launch: 1024 threads per block, with up to
    // 512 block columns that the kernel linearizes with gridDim.y.
    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    const size_t blocks =
        (static_cast<size_t>(inputSize) + kThreadsPerBlock - 1) /
        kThreadsPerBlock;
    const size_t blockX = blocks < kMaxGridX ? blocks : kMaxGridX;
    const size_t blockY = (blocks + kMaxGridX - 1) / kMaxGridX;
    dim3 threads(static_cast<unsigned int>(kThreadsPerBlock), 1);
    dim3 grid(static_cast<unsigned int>(blockX),
              static_cast<unsigned int>(blockY));

    // NOTE(review): launch errors are not checked here.
    ConvolutionDepthwiseInputBackward<T>
        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<grid, threads, 0, STREAM_DEFAULT>>>(inputSize,
                                               outputGrad,
                                               filterData,
                                               batchSize,
                                               outputChannels,
                                               outputHeight,
                                               outputWidth,
                                               inputHeight,
                                               inputWidth,
                                               filterHeight,
                                               filterWidth,
                                               strideH,
                                               strideW,
                                               paddingH,
                                               paddingW,
                                               inputGrad);
  }
};

/**
 * GPU specialization of the depthwise convolution filter-gradient functor,
 * processing a single batch sample (num_i) per call.
 *
 * Step 1: ConvolutionDepthwiseFilterBackward fills colData, laid out as
 *   [c * filterHeight * filterWidth][outputHeight * outputWidth], with the
 *   per-position products of output gradient and input.
 * Step 2: a GEMM with N = 1 multiplies colData (M x K) by multiplierData
 *   (K x 1) and adds the result into filterGrad (M x 1) with alpha = beta = 1.
 *   NOTE(review): multiplierData is presumably a vector of K ones, which makes
 *   this a row sum over output positions — confirm with caller.
 */
template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public:
  /**
   * @param num_i           batch sample to process.
   * @param colDataSize     number of elements in colData (presumably
   *                        outputChannels * filterHeight * filterWidth *
   *                        outputHeight * outputWidth). Nothing is done when
   *                        it is <= 0.
   * @param outputGrad      device output gradient.
   * @param inputData       device input.
   * @param colData         device scratch buffer, overwritten.
   * @param multiplierData  device vector of K = outputHeight * outputWidth
   *                        entries used as the GEMM right-hand side.
   * @param filterGrad      device filter gradient, accumulated into.
   */
  void operator()(int num_i,
                  int colDataSize,
                  const T* outputGrad,
                  const T* inputData,
                  int batchSize,
                  int outputChannels,
                  int outputHeight,
                  int outputWidth,
                  int inputHeight,
                  int inputWidth,
                  int filterHeight,
                  int filterWidth,
                  int strideH,
                  int strideW,
                  int paddingH,
                  int paddingW,
                  T* colData,
                  T* multiplierData,
                  T* filterGrad) {
    // An empty column buffer contributes nothing to filterGrad. Returning
    // early also avoids a zero-dimension launch and a division by zero when
    // computing M below.
    if (colDataSize <= 0 || outputHeight <= 0 || outputWidth <= 0) {
      return;
    }

    // 1024 threads per block, with up to 512 block columns that the kernel
    // linearizes with gridDim.y.
    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    const size_t blocks =
        (static_cast<size_t>(colDataSize) + kThreadsPerBlock - 1) /
        kThreadsPerBlock;
    const size_t blockX = blocks < kMaxGridX ? blocks : kMaxGridX;
    const size_t blockY = (blocks + kMaxGridX - 1) / kMaxGridX;
    dim3 threads(static_cast<unsigned int>(kThreadsPerBlock), 1);
    dim3 grid(static_cast<unsigned int>(blockX),
              static_cast<unsigned int>(blockY));

    // NOTE(review): launch errors are not checked here.
    ConvolutionDepthwiseFilterBackward<T>
        // NOLINT_NEXT_LINE(whitespace/operators)
        <<<grid, threads, 0, STREAM_DEFAULT>>>(num_i,
                                               colDataSize,
                                               outputGrad,
                                               inputData,
                                               batchSize,
                                               outputChannels,
                                               outputHeight,
                                               outputWidth,
                                               inputHeight,
                                               inputWidth,
                                               filterHeight,
                                               filterWidth,
                                               strideH,
                                               strideW,
                                               paddingH,
                                               paddingW,
                                               colData);

    // Use the template's own element type. Hard-coding the global `real`
    // would break (or silently mismatch pointer types) for any T != real.
    GemmFunctor<DEVICE_TYPE_GPU, T> gemm;
    const int M = colDataSize / outputHeight / outputWidth;
    const int N = 1;
    const int K = outputHeight * outputWidth;
    // TODO: N == 1, so this is a matrix-vector product; a gemv would suffice.
    gemm(CblasNoTrans,
         CblasNoTrans,
         M,
         N,
         K,
         static_cast<T>(1.0),
         colData,
         K,
         multiplierData,
         N,
         static_cast<T>(1.0),
         filterGrad,
         N);
  }
};

// Element type used for the explicit instantiations below: double when
// PADDLE_TYPE_DOUBLE is defined, float otherwise.
#ifndef PADDLE_TYPE_DOUBLE
typedef float real;
#else
typedef double real;
#endif

// Emit the GPU functor specializations for the configured element type.
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, real>;
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, real>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, real>;

}  // namespace paddle