/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include <float.h>
#include "hl_base.h"
#include "hl_cnn.h"
#include "hl_device_functions.cuh"

/*
 * Max-pooling forward kernel.
 *
 * One thread per pooled output element. The flat thread id is decoded as
 * (frame, channel, pooledRow, pooledCol) with the pooled column varying
 * fastest. Launch with a 1-D grid of 1-D blocks covering at least nthreads
 * threads; surplus threads exit immediately.
 *
 * inputData is read as dense planes of height x width, one plane per
 * (frame, channel) pair. ksizeW/ksizeH are the window width/height and
 * offsetH/offsetW are subtracted from every window origin (padding).
 * Output element (c, ph, pw) of a frame is written at
 * frame * tgtStride + (c * pooledH + ph) * pooledW + pw.
 * Presumably tgtStride >= channels * pooledH * pooledW (TODO: confirm with
 * callers).
 *
 * A window that contains no in-image pixel produces -FLT_MAX.
 */
__global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
                                 const int channels, const int height,
                                 const int width,
                                 const int pooledH, const int pooledW,
                                 const int ksizeW, const int ksizeH,
                                 const int strideH, const int strideW,
                                 const int offsetH, const int offsetW,
                                 real* tgtData, const int tgtStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Decode the flat id into (frame, channel, pooled row, pooled column).
  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // Window bounds in input coordinates, clipped to the image.
  const int rowOrigin = outRow * strideH - offsetH;
  const int colOrigin = outCol * strideW - offsetW;
  const int rowBegin = max(rowOrigin, 0);
  const int colBegin = max(colOrigin, 0);
  const int rowEnd = min(rowOrigin + ksizeH, height);
  const int colEnd = min(colOrigin + ksizeW, width);

  const real* plane = inputData + (frame * channels + chan) * height * width;
  real best = -FLT_MAX;
  for (int y = rowBegin; y < rowEnd; ++y) {
    const real* row = plane + y * width;
    for (int x = colBegin; x < colEnd; ++x) {
      best = (best < row[x]) ? row[x] : best;
    }
  }

  const int perFrame = pooledW * pooledH * channels;
  tgtData[frame * tgtStride + tid % perFrame] = best;
}

/*
 * Host launcher for KeMaxPoolForward.
 *
 * Launches one thread per pooled output element
 * (frameCnt * channels * pooledH * pooledW threads) in 1024-thread blocks on
 * STREAM_DEFAULT, then runs the project's CHECK_SYNC macro.
 * sizeX/sizeY are the window width/height. paddingH/paddingW are subtracted
 * from each window origin. tgtStride is the distance between consecutive
 * frames in tgtData.
 *
 * An empty problem returns without launching. A grid with zero blocks is an
 * invalid launch configuration in CUDA, so launching it would only raise an
 * error.
 */
void hl_maxpool_forward(const int frameCnt, const real* inputData,
                        const int channels,
                        const int height, const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real* tgtData, const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  if (num_kernels == 0) {
    return;
  }
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeMaxPoolForward<<< blocks, kThreadsPerBlock, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, channels, height, width,
            pooledH, pooledW, sizeX, sizeY, strideH, strideW,
            paddingH, paddingW, tgtData, tgtStride);
  CHECK_SYNC("hl_maxpool_forward failed");
}

/*
 * Max-pooling backward kernel.
 *
 * One thread per INPUT element. The flat thread id is decoded as
 * (frame, channel, row, col) with the column varying fastest. Each thread
 * finds the range of pooled windows whose footprint covers its element, in
 * coordinates shifted by the padding. It sums outGrad over every such window
 * whose pooled value (outData) equals this element's input value. Ties
 * therefore all receive the gradient. The result is blended as
 *   targetGrad = scaleB * targetGrad + scaleA * sum.
 * outData/outGrad use outStride as their per-frame stride.
 * Launch with a 1-D grid of 1-D blocks covering at least nthreads threads.
 */
__global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
                                  const real* outData, const real* outGrad,
                                  const int channels, const int height,
                                  const int width,
                                  const int pooledH, const int pooledW,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
                                  real scaleA, real scaleB,
                                  real* targetGrad, const int outStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Position of this input element, shifted into padded coordinates.
  const int paddedX = tid % width + padW;
  const int paddedY = (tid / width) % height + padH;
  const int chan = (tid / width / height) % channels;
  const int frame = tid / width / height / channels;

  // Half-open range [first, last) of pooled windows covering the element.
  int firstRow = 0;
  if (paddedY >= sizeY) {
    firstRow = (paddedY - sizeY) / strideH + 1;
  }
  int firstCol = 0;
  if (paddedX >= sizeX) {
    firstCol = (paddedX - sizeX) / strideW + 1;
  }
  const int lastRow = (paddedY < 0) ? 0 : min(paddedY / strideH + 1, pooledH);
  const int lastCol = (paddedX < 0) ? 0 : min(paddedX / strideW + 1, pooledW);

  const int planeOffset = frame * outStride + chan * pooledH * pooledW;
  const real* pooledVal = outData + planeOffset;
  const real* pooledGrad = outGrad + planeOffset;
  const real value = inputData[tid];

  real acc = 0;
  for (int r = firstRow; r < lastRow; ++r) {
    for (int col = firstCol; col < lastCol; ++col) {
      const int k = r * pooledW + col;
      if (pooledVal[k] == value) {
        acc += pooledGrad[k];
      }
    }
  }
  targetGrad[tid] = scaleB * targetGrad[tid] + scaleA * acc;
}

/*
 * Host launcher for KeMaxPoolBackward.
 *
 * Launches one thread per INPUT element (frameCnt * channels * height * width)
 * in 1024-thread blocks on STREAM_DEFAULT, then runs the project's CHECK_SYNC
 * macro. Each thread writes
 *   targetGrad = scaleB * targetGrad + scaleA * (routed output gradient).
 * outStride is the per-frame stride of outData/outGrad.
 *
 * Returns without launching when there are no input elements, because the
 * CUDA runtime rejects a zero-block grid as an invalid configuration.
 */
void hl_maxpool_backward(const int frameCnt, const real* inputData,
                        const real* outData, const real* outGrad,
                        const int channels, const int height,
                        const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real scaleA, real scaleB,
                        real* targetGrad, const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = height * width * channels * frameCnt;
  if (num_kernels == 0) {
    return;
  }
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeMaxPoolBackward<<< blocks, kThreadsPerBlock, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, outData, outGrad, channels,
            height, width, pooledH, pooledW, sizeX, sizeY,
            strideH, strideW,
            paddingH, paddingW,
            scaleA, scaleB,
            targetGrad, outStride);
  CHECK_SYNC("hl_maxpool_backward");
}

/*
 * Average-pooling forward kernel.
 *
 * One thread per pooled output element, decoded as
 * (frame, channel, pooledRow, pooledCol) with the pooled column varying
 * fastest. The window is first clipped to the padded extent
 * (height + padH, width + padW), and its area at that point is the divisor.
 * Padded cells therefore count toward the average. Only in-image pixels are
 * summed. Output layout matches KeMaxPoolForward:
 * frame * tgtStride + (c * pooledH + ph) * pooledW + pw.
 * Launch with a 1-D grid of 1-D blocks covering at least nthreads threads.
 */
__global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
                                 const int channels,
                                 const int height, const int width,
                                 const int pooledH, const int pooledW,
                                 const int sizeX, const int sizeY,
                                 const int strideH, const int strideW,
                                 const int padH, const int padW,
                                 real* tgtData, const int tgtStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // Window in padded coordinates; its area is the averaging divisor.
  const int rowLo = outRow * strideH - padH;
  const int colLo = outCol * strideW - padW;
  const int rowHi = min(rowLo + sizeY, height + padH);
  const int colHi = min(colLo + sizeX, width + padW);
  const int windowArea = (rowHi - rowLo) * (colHi - colLo);

  // Part of the window that actually lies inside the image.
  const int yBegin = max(rowLo, 0);
  const int xBegin = max(colLo, 0);
  const int yEnd = min(rowHi, height);
  const int xEnd = min(colHi, width);

  const real* plane = inputData + (frame * channels + chan) * height * width;
  real sum = 0;
  for (int y = yBegin; y < yEnd; ++y) {
    for (int x = xBegin; x < xEnd; ++x) {
      sum += plane[y * width + x];
    }
  }

  const int perFrame = pooledW * pooledH * channels;
  tgtData[frame * tgtStride + tid % perFrame] = sum / windowArea;
}

/*
 * Host launcher for KeAvgPoolForward.
 *
 * Launches one thread per pooled output element
 * (frameCnt * channels * pooledH * pooledW threads) in 1024-thread blocks on
 * STREAM_DEFAULT, then runs the project's CHECK_SYNC macro. tgtStride is the
 * distance between consecutive frames in tgtData.
 *
 * Returns without launching when the output is empty, because the CUDA
 * runtime rejects a zero-block grid as an invalid configuration.
 */
void hl_avgpool_forward(const int frameCnt, const real* inputData,
                        const int channels,
                        const int height, const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real* tgtData, const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  if (num_kernels == 0) {
    return;
  }
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPoolForward<<< blocks, kThreadsPerBlock, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, channels,
            height, width, pooledH, pooledW,
            sizeX, sizeY, strideH, strideW,
            paddingH, paddingW, tgtData, tgtStride);
  CHECK_SYNC("hl_avgpool_forward failed");
}

/*
 * Average-pooling backward kernel.
 *
 * One thread per INPUT element, decoded as (frame, channel, row, col) with the
 * column varying fastest. For every pooled window covering the element, the
 * thread adds that window's output gradient divided by the window's area.
 * The area is computed exactly as in KeAvgPoolForward: clipped to the padded
 * extent, so padded cells are counted. The result is blended as
 *   tgtGrad = scaleB * tgtGrad + scaleA * sum.
 * outGrad uses outStride as its per-frame stride.
 * Launch with a 1-D grid of 1-D blocks covering at least nthreads threads.
 */
__global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
                                  const int channels, const int height,
                                  const int width,
                                  const int pooledH, const int pooledW,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
                                  real scaleA, real scaleB,
                                  real* tgtGrad, const int outStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Position of this input element in padded coordinates.
  const int paddedX = tid % width + padW;
  const int paddedY = (tid / width) % height + padH;
  const int chan = (tid / width / height) % channels;
  const int frame = tid / width / height / channels;

  // Half-open range [first, last) of pooled windows covering the element.
  const int firstRow =
      (paddedY < sizeY) ? 0 : (paddedY - sizeY) / strideH + 1;
  const int firstCol =
      (paddedX < sizeX) ? 0 : (paddedX - sizeX) / strideW + 1;
  const int lastRow = (paddedY < 0) ? 0 : min(paddedY / strideH + 1, pooledH);
  const int lastCol = (paddedX < 0) ? 0 : min(paddedX / strideW + 1, pooledW);

  const real* pooledGrad =
      outGrad + (frame * outStride + chan * pooledH * pooledW);

  real acc = 0;
  for (int r = firstRow; r < lastRow; ++r) {
    // Height of window r, clipped to the padded extent (same as forward).
    const int top = r * strideH - padH;
    const int winH = min(top + sizeY, height + padH) - top;
    for (int c = firstCol; c < lastCol; ++c) {
      const int left = c * strideW - padW;
      const int winW = min(left + sizeX, width + padW) - left;
      acc += pooledGrad[r * pooledW + c] / (winH * winW);
    }
  }
  tgtGrad[tid] = scaleB * tgtGrad[tid] + scaleA * acc;
}

/*
 * Host launcher for KeAvgPoolBackward.
 *
 * Launches one thread per INPUT element (frameCnt * channels * height * width)
 * in 1024-thread blocks on STREAM_DEFAULT, then runs the project's CHECK_SYNC
 * macro. Each thread writes
 *   backGrad = scaleB * backGrad + scaleA * (accumulated output gradient).
 * outStride is the per-frame stride of outGrad.
 *
 * Returns without launching when there are no input elements, because the
 * CUDA runtime rejects a zero-block grid as an invalid configuration.
 */
void hl_avgpool_backward(const int frameCnt, const real* outGrad,
                         const int channels,
                         const int height, const int width,
                         const int pooledH, const int pooledW,
                         const int sizeX, const int sizeY,
                         const int strideH, const int strideW,
                         const int paddingH, const int paddingW,
                         real scaleA, real scaleB,
                         real* backGrad, const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = height * width * channels * frameCnt;
  if (num_kernels == 0) {
    return;
  }
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPoolBackward<<< blocks, kThreadsPerBlock, 0, STREAM_DEFAULT >>>
           (num_kernels, outGrad, channels, height, width,
            pooledH, pooledW, sizeX, sizeY,
            strideH, strideW,
            paddingH, paddingW,
            scaleA, scaleB,
            backGrad, outStride);
  CHECK_SYNC("hl_avgpool_backward failed");
}

/*
 * Bilinear interpolation forward kernel.
 *
 * One thread per output value; outputH * outputW threads in total.
 * The thread id selects a row (tid / outputW) and a column (tid % outputW).
 * Each row of `in` (width inputW) and `out` (width outputW) is split into
 * numChannels equal planes. Inside a plane, element (y, x) is at
 * y * inImgW + x (input) or y * outImgW + x (output).
 *
 * The source coordinate is trunc(ratio * outCoord) on each axis, and the
 * fractional remainder is the interpolation weight. On the last input
 * row/column the neighbour offset is 0, so the sample is clamped to the edge.
 * inputH and outImgH are not read by this kernel.
 * NOTE(review): presumably ratioH/ratioW map output pixel indices back to
 * input pixel indices (e.g. (in - 1) / (out - 1)); confirm with callers.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int total = outputH * outputW;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total) {
    return;
  }

  const int sample = tid / outputW;  // row of in / out
  const int col = tid % outputW;     // column inside that row
  const int inPlane = inputW / numChannels;
  const int outPlane = outputW / numChannels;
  const int channel = col / outPlane;

  // Vertical source row and blend weights.
  const int outY = (col % outPlane) / outImgW;
  const int inY = ratioH * outY;
  const int dy = (inY < inImgH - 1) ? 1 : 0;
  const real wyBottom = ratioH * outY - inY;
  const real wyTop = 1.f - wyBottom;

  // Horizontal source column and blend weights.
  const int outX = tid % outImgW;
  const int inX = ratioW * outX;
  const int dx = (inX < inImgW - 1) ? 1 : 0;
  const real wxRight = ratioW * outX - inX;
  const real wxLeft = 1.f - wxRight;

  const real* src =
      &in[sample * inputW + channel * inPlane + inY * inImgW + inX];
  const real* srcBelow = src + dy * inImgW;

  out[sample * outputW + col] =
      wyTop * (wxLeft * src[0] + wxRight * src[dx]) +
      wyBottom * (wxLeft * srcBelow[0] + wxRight * srcBelow[dx]);
}

/*
 * Host launcher for KeBilinearInterpFw.
 *
 * Launches one thread per output value (outputH * outputW threads) in
 * 1024-thread blocks on STREAM_DEFAULT, then runs the project's CHECK_SYNC
 * macro. All size arguments are forwarded unchanged to the kernel.
 *
 * Returns without launching when the output is empty, because the CUDA
 * runtime rejects a zero-block grid as an invalid configuration.
 */
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int kThreadsPerBlock = 1024;
  int threadNum = outputH * outputW;
  if (threadNum == 0) {
    return;
  }
  int blocks = (threadNum + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeBilinearInterpFw<<< blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
    inData, inImgH, inImgW, inputH, inputW, outData, outImgH,
    outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}

/*
 * Bilinear interpolation backward kernel.
 *
 * Mirrors KeBilinearInterpFw: one thread per output-gradient value
 * (outputH * outputW threads), with the same row/plane indexing and the same
 * source coordinates and weights. Each thread scatters its gradient into the
 * four (possibly coincident, at edges) source cells of `in`. It uses
 * paddle::paddleAtomicAdd because neighbouring output pixels can share source
 * cells. `in` is accumulated into, not overwritten. inputH and outImgH are
 * not read.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int total = outputH * outputW;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total) {
    return;
  }

  const int sample = tid / outputW;  // row of in / out
  const int col = tid % outputW;     // column inside that row
  const int inPlane = inputW / numChannels;
  const int outPlane = outputW / numChannels;
  const int channel = col / outPlane;

  // Vertical source row and blend weights.
  const int outY = (col % outPlane) / outImgW;
  const int inY = ratioH * outY;
  const int dy = (inY < inImgH - 1) ? 1 : 0;
  const real wyBottom = ratioH * outY - inY;
  const real wyTop = 1.f - wyBottom;

  // Horizontal source column and blend weights.
  const int outX = tid % outImgW;
  const int inX = ratioW * outX;
  const int dx = (inX < inImgW - 1) ? 1 : 0;
  const real wxRight = ratioW * outX - inX;
  const real wxLeft = 1.f - wxRight;

  real* dst = &in[sample * inputW + channel * inPlane + inY * inImgW + inX];
  real* dstBelow = dst + dy * inImgW;
  const real* grad = &out[sample * outputW + col];

  paddle::paddleAtomicAdd(&dst[0], wyTop * wxLeft * grad[0]);
  paddle::paddleAtomicAdd(&dst[dx], wyTop * wxRight * grad[0]);
  paddle::paddleAtomicAdd(&dstBelow[0], wyBottom * wxLeft * grad[0]);
  paddle::paddleAtomicAdd(&dstBelow[dx], wyBottom * wxRight * grad[0]);
}

/*
 * Host launcher for KeBilinearInterpBw.
 *
 * Launches one thread per output-gradient value (outputH * outputW threads)
 * in 1024-thread blocks on STREAM_DEFAULT, then runs the project's CHECK_SYNC
 * macro. The kernel accumulates into inGrad with atomic adds and does not
 * clear it first.
 *
 * Returns without launching when there is no output gradient, because the
 * CUDA runtime rejects a zero-block grid as an invalid configuration.
 */
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int kThreadsPerBlock = 1024;
  int threadNum = outputH * outputW;
  if (threadNum == 0) {
    return;
  }
  int blocks = (threadNum + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeBilinearInterpBw<<< blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
    inGrad, inImgH, inImgW, inputH, inputW, outGrad, outImgH,
    outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}

/*
 * Maxout forward kernel.
 *
 * One thread per output element; nthreads = batchSize * size. Output element
 * (batch, channel, feat) sits at batch * size + channel * featLen + feat.
 * Its `groups` candidates are read from inData at
 *   batch * size * groups + (channel * groups + g) * featLen + feat.
 * The kernel writes the maximum to outData and the winning group index to
 * idData. Ties keep the lowest group index because the comparison is strict.
 *
 * The global id is formed in size_t. blockIdx.x * blockDim.x alone is a 32-bit
 * unsigned product, and storing it in an int made ids >= 2^31 negative, so
 * they failed the bounds check against the size_t count.
 */
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
                                real * outData, int* idData,
                                size_t size, size_t featLen, size_t groups) {
  size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    size_t batch_idx = index / size;
    size_t i = index % size;
    size_t channel_idx = i / featLen;
    size_t feat_idx = i % featLen;
    // Offset of group 0 for this (batch, channel, feat); groups are featLen apart.
    size_t data_idx =
        (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
    real best = inData[data_idx];
    int bestId = 0;
    for (size_t g = 1; g < groups; ++g) {
      real candidate = inData[data_idx + g * featLen];
      if (candidate > best) {
        best = candidate;
        bestId = static_cast<int>(g);
      }
    }
    outData[index] = best;
    idData[index] = bestId;
  }
}

/*
 * Host launcher for maxoutFpCompute.
 *
 * Launches one thread per output element (size * batchSize threads) in
 * 1024-thread blocks on STREAM_DEFAULT, then runs the project's CHECK_SYNC
 * macro. The element count stays in size_t to match the kernel's size_t
 * count; narrowing it to int would truncate very large products.
 * Returns without launching for an empty batch, because the CUDA runtime
 * rejects a zero-block grid as an invalid configuration.
 */
void hl_maxout_forward(const real* inData, real* outData,
                       int* idData, size_t batchSize, size_t size,
                       size_t featLen, size_t groups) {
  const size_t kThreadsPerBlock = 1024;
  const size_t num_kernels = size * batchSize;
  if (num_kernels == 0) {
    return;
  }
  const size_t blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
  maxoutFpCompute<<< static_cast<unsigned int>(blocks),
                     static_cast<unsigned int>(kThreadsPerBlock),
                     0, STREAM_DEFAULT>>>(
    num_kernels, inData, outData, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_forward failed");
}

/*
 * Maxout backward kernel.
 *
 * One thread per output-gradient element; nthreads = batchSize * size.
 * Each thread adds outGrad[index] into the inGrad slot of the group recorded
 * in idData[index] by the forward pass. The slot is at
 *   batch * size * groups + (channel * groups + id) * featLen + feat.
 * Distinct (batch, channel, feat) triples map to distinct slots, so the plain
 * non-atomic += cannot race. inGrad is accumulated into, not overwritten.
 *
 * The global id is formed in size_t so it cannot wrap or go negative for
 * counts beyond the int range (nthreads is size_t).
 */
__global__ void maxoutBpCompute(size_t nthreads, real* inGrad,
                                const real* outGrad, const int* idData,
                                size_t size, size_t featLen, size_t groups) {
  size_t index = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    size_t batch_idx = index / size;
    size_t i = index % size;
    size_t channel_idx = i / featLen;
    size_t feat_idx = i % featLen;
    size_t batchBase = batch_idx * size;  // == index - i
    size_t gradIdx =
        (channel_idx * groups + idData[index]) * featLen + feat_idx;
    inGrad[batchBase * groups + gradIdx] += outGrad[index];
  }
}

/*
 * Host launcher for maxoutBpCompute.
 *
 * Launches one thread per output-gradient element (size * batchSize threads)
 * in 1024-thread blocks on STREAM_DEFAULT, then runs the project's CHECK_SYNC
 * macro. The element count stays in size_t to match the kernel's size_t
 * count. Returns without launching for an empty batch, because the CUDA
 * runtime rejects a zero-block grid as an invalid configuration.
 */
void hl_maxout_backward(real* inGrad, const real* outGrad,
                        const int* idData, size_t batchSize, size_t size,
                        size_t featLen, size_t groups) {
  const size_t kThreadsPerBlock = 1024;
  const size_t num_kernels = size * batchSize;
  if (num_kernels == 0) {
    return;
  }
  const size_t blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
  maxoutBpCompute<<< static_cast<unsigned int>(blocks),
                     static_cast<unsigned int>(kThreadsPerBlock),
                     0, STREAM_DEFAULT >>>(
    num_kernels, inGrad, outGrad, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_backward failed");
}