hl_cuda_cnn.cu 23.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include <float.h>
#include "hl_base.h"
#include "hl_cnn.h"
L
liaogang 已提交
19
#include "hl_device_functions.cuh"
Z
zhangjinchao01 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148

/**
 * Image-to-column expansion kernel.
 *
 * Threads are addressed through a 2-D grid of 1-D blocks; each thread whose
 * flat index is below n handles one (channel_in, h_out, w_out) triple and
 * copies the blockH x blockW window starting at
 * (h_out * strideH, w_out * strideW) into data_col. Window positions that,
 * after subtracting the padding, fall outside [0, height) x [0, width)
 * are written as 0.
 *
 * Consecutive window elements are stored height_col * width_col apart in
 * data_col.
 */
__global__ void KeFeature2col(size_t n, size_t height, const real* data_im,
                              size_t blockH, size_t blockW, size_t width,
                              size_t strideH, size_t strideW,
                              size_t paddingH, size_t paddingW,
                              size_t height_col, size_t width_col,
                              real* data_col) {
  size_t index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= n) {
    return;
  }

  // Split the flat index into (channel_in, h_out, w_out).
  const size_t w_out = index % width_col;
  const size_t remaining = index / width_col;
  const size_t h_out = remaining % height_col;
  const size_t channel_in = remaining / height_col;

  const size_t channel_out = channel_in * blockH * blockW;
  const size_t h_in = h_out * strideH;
  const size_t w_in = w_out * strideW;
  // Distance in data_col between two consecutive window elements.
  const size_t colPlane = height_col * width_col;

  real* dst = data_col + (channel_out * height_col + h_out) * width_col + w_out;
  for (size_t i = 0; i < blockH; ++i) {
    for (size_t j = 0; j < blockW; ++j, dst += colPlane) {
      const int paddedRow = int(h_in + i);
      const int paddedCol = int(w_in + j);
      const int row = paddedRow - (int)paddingH;
      const int col = paddedCol - (int)paddingW;
      const bool inside =
        row >= 0 && row < (int)height && col >= 0 && col < (int)width;
      if (inside) {
        // Row index inside the channel_in-th image plane.
        const int srcRow = paddedRow + channel_in * height - paddingH;
        const int srcCol = paddedCol - paddingW;
        *dst = data_im[srcRow * width + srcCol];
      } else {
        *dst = 0;
      }
    }
  }
}

/**
 * Launches KeFeature2col on STREAM_DEFAULT with one thread per
 * (channel, outputH, outputW) element.
 *
 * The launch uses blocks of 1024 threads arranged in a 512 x blockY grid,
 * where blockY is chosen so that the grid covers channels * outputH * outputW
 * threads. When that product is 0 the function returns without launching.
 *
 * @param dataIm    input image data passed to the kernel as data_im.
 * @param dataCol   output buffer passed to the kernel as data_col.
 * Remaining parameters are forwarded to the kernel unchanged.
 */
void hl_expand_feature2col(const real* dataIm, size_t channels,
                           size_t height, size_t width,
                           size_t blockH, size_t blockW,
                           size_t strideH, size_t strideW,
                           size_t paddingH, size_t paddingW,
                           size_t outputH, size_t outputW,
                           real* dataCol) {
  size_t numKernels = channels * outputH * outputW;
  // With nothing to compute blockY would be 0, and a grid dimension of 0 is
  // an invalid launch configuration. Skip the launch instead.
  if (numKernels == 0) {
    return;
  }

  size_t blocks = (numKernels + 1024 -1) / 1024;
  size_t blockX = 512;
  size_t blockY = (blocks+512-1)/512;
  dim3 threads(1024, 1);
  dim3 grid(blockX, blockY);
  KeFeature2col<<< grid, threads, 0, STREAM_DEFAULT >>>
           (numKernels, height, dataIm, blockH, blockW, width,
           strideH, strideW, paddingH, paddingW,
           outputH, outputW, dataCol);
  CHECK_SYNC("hl_expand_feature2col failed");
}

/**
 * Column-to-image accumulation kernel.
 *
 * Each thread with flat index below n handles one (c, h, w) position of the
 * padded image, where height and width are the padded extents supplied by
 * the caller. Positions inside the padding border are skipped. For the
 * others, every entry of data_col whose window covers the position is
 * summed, and the result is blended into the unpadded image:
 *   data_im[...] = alpha * sum + beta * data_im[...].
 */
__global__ void KeCol2Feature(size_t n, const real* data_col, size_t height,
                              size_t width, size_t channels,
                              size_t blockH, size_t blockW,
                              size_t strideH, size_t strideW,
                              size_t paddingH, size_t paddingW,
                              size_t height_col, size_t width_col,
                              real* data_im, real alpha, real beta) {
  size_t index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= n) {
    return;
  }

  int w = int(index % width);
  int h = int((index / width) % height);
  int c = int(index / (width * height));

  // Extents of the image once the padding border is removed.
  const size_t imgW = width - 2 * paddingW;
  const size_t imgH = height - 2 * paddingH;

  const bool insideImage =
    (w - (int)paddingW) >= 0 &&
    (w - (int)paddingW) < imgW &&
    (h - (int)paddingH) >= 0 &&
    (h - paddingH) < imgH;
  if (!insideImage) {
    return;
  }

  // Range of column positions whose window covers (h, w).
  const int w_col_start =
    (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
  const int w_col_end =
    min((int)(w / (int)strideW + 1), (int)(width_col));
  const int h_col_start =
    (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
  const int h_col_end = min(int(h / strideH + 1), int(height_col));

  real val = 0;
  for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
      // Row of data_col holding this (channel, offset-in-window) pair.
      int c_col = int(c * blockH* blockW) +
        (h - h_col * (int)strideH) * (int)blockW +
        (w - w_col * (int)strideW);
      val += data_col[(c_col * height_col + h_col) * width_col + w_col];
    }
  }

  // Switch from padded to unpadded coordinates before touching data_im.
  h -= paddingH;
  w -= paddingW;
  const size_t dst = c * (imgW * imgH) + h * imgW + w;
  real tD = data_im[dst];
  data_im[dst] = alpha * val + beta*tD;
}

/**
 * Launches KeCol2Feature on STREAM_DEFAULT with one thread per element of
 * the padded image: channels * (height + 2*paddingH) * (width + 2*paddingW).
 *
 * The kernel receives the padded height and width. The launch uses blocks of
 * 1024 threads in a 512 x blockY grid. When the padded image has no elements
 * the function returns without launching.
 *
 * @param dataCol  column buffer read by the kernel.
 * @param dataIm   image buffer updated as alpha * sum + beta * dataIm.
 * Remaining parameters are forwarded to the kernel.
 */
void hl_shrink_col2feature(const real * dataCol, size_t channels,
                           size_t height, size_t width,
                           size_t blockH, size_t blockW,
                           size_t strideH, size_t strideW,
                           size_t paddingH, size_t paddingW,
                           size_t outputH, size_t outputW,
                           real* dataIm, real alpha, real beta) {
  size_t numKernels = channels * (height + 2*paddingH) * (width + 2*paddingW);
  // With nothing to compute blockY would be 0, and a grid dimension of 0 is
  // an invalid launch configuration. Skip the launch instead.
  if (numKernels == 0) {
    return;
  }

  size_t blocks = (numKernels + 1024 -1) / 1024;
  size_t blockX = 512;
  size_t blockY = (blocks+512-1)/512;
  dim3 threads(1024, 1);
  dim3 grid(blockX, blockY);

  // To avoid involving atomic operations, we will launch one kernel per
  // bottom dimension, and then in the kernel add up the top dimensions.
  KeCol2Feature<<< grid, threads, 0, STREAM_DEFAULT >>>
           (numKernels, dataCol, height + 2*paddingH, width + 2*paddingW,
           channels, blockH, blockW, strideH, strideW, paddingH, paddingW,
           outputH, outputW, dataIm, alpha, beta);
  CHECK_SYNC("hl_shrink_col2feature failed");
}

149 150 151 152 153 154 155
/**
 * Max-pooling forward kernel: one thread per pooled output element.
 *
 * The flat thread index is decomposed into (frameNum, c, ph, pw). The
 * pooling window starts at (ph * strideH - offsetH, pw * strideW - offsetW),
 * spans ksizeH x ksizeW, and is clipped to the input plane. The maximum of
 * the clipped window (initialised to -FLT_MAX) is written to tgtData at
 * frameNum * tgtStride plus the element's offset within its frame.
 */
__global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
                                 const int channels, const int height,
                                 const int width,
                                 const int pooledH, const int pooledW,
                                 const int ksizeW, const int ksizeH,
                                 const int strideH, const int strideW,
                                 const int offsetH, const int offsetW,
                                 real* tgtData, const int tgtStride) {
  int index =  blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  const int pw = index % pooledW;
  const int ph = (index / pooledW) % pooledH;
  const int c = (index / pooledW / pooledH) % channels;
  const int frameNum = index / pooledW / pooledH / channels;

  // Window bounds, clipped to the input plane.
  const int hOrigin = ph * strideH - offsetH;
  const int wOrigin = pw * strideW - offsetW;
  const int hend = min(hOrigin + ksizeH, height);
  const int wend = min(wOrigin + ksizeW, width);
  const int hstart = max(hOrigin, 0);
  const int wstart = max(wOrigin, 0);

  const real* plane = inputData + (frameNum * channels + c) * height * width;
  real maxval = -FLT_MAX;
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      const real candidate = plane[h * width + w];
      if (maxval < candidate) {
        maxval = candidate;
      }
    }
  }

  const int frameSize = pooledW * pooledH * channels;
  const int tgtIndex = index % frameSize + frameNum * tgtStride;
  tgtData[tgtIndex] = maxval;
}

183 184 185 186 187 188 189
/**
 * Launches KeMaxPoolForward on STREAM_DEFAULT with one thread per pooled
 * output element (pooledH * pooledW * channels * frameCnt), using 1-D blocks
 * of 1024 threads.
 *
 * sizeX and sizeY are passed to the kernel as ksizeW and ksizeH, and
 * paddingH / paddingW as offsetH / offsetW. When there are no output
 * elements the function returns without launching.
 */
void hl_maxpool_forward(const int frameCnt, const real* inputData,
                        const int channels,
                        const int height, const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real* tgtData, const int tgtStride) {
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (num_kernels == 0) {
    return;
  }

  int blocks = (num_kernels + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KeMaxPoolForward<<< grid, threads, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, channels, height, width,
           pooledH, pooledW, sizeX, sizeY, strideH, strideW,
           paddingH, paddingW, tgtData, tgtStride);
  CHECK_SYNC("hl_maxpool_forward failed");
}

204
/**
 * Max-pooling backward kernel: one thread per input element.
 *
 * For the input element at the flat thread index, the kernel walks every
 * pooled output whose window can contain it and adds that output's gradient
 * whenever the pooled value equals the input value. The result is blended
 * into targetGrad as scaleB * targetGrad[index] + scaleA * gradient.
 *
 * outData and outGrad are addressed per frame with stride outStride.
 */
__global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
                                  const real* outData, const real* outGrad,
                                  const int channels, const int height,
                                  const int width,
                                  const int pooledH, const int pooledW,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
                                  real scaleA, real scaleB,
                                  real* targetGrad, const int outStride) {
  int index = blockIdx.x  * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Position of this input element in padded coordinates.
  const int offsetW = index % width + padW;
  const int offsetH = (index / width) % height + padH;
  const int offsetC = (index / width / height) % channels;
  const int frameNum = index / width / height / channels;

  // Range of pooled outputs whose window covers the element.
  int phstart = 0;
  if (offsetH >= sizeY) {
    phstart = (offsetH - sizeY) / strideH + 1;
  }
  int pwstart = 0;
  if (offsetW >= sizeX) {
    pwstart = (offsetW - sizeX) / strideW + 1;
  }
  int phend = 0;
  if (offsetH >= 0) {
    phend = min(offsetH / strideH + 1, pooledH);
  }
  int pwend = 0;
  if (offsetW >= 0) {
    pwend = min(offsetW / strideW + 1, pooledW);
  }

  const real input = inputData[index];
  const int planeOffset = frameNum * outStride + offsetC * pooledH * pooledW;
  const real* pooledValue = outData + planeOffset;
  const real* pooledGrad = outGrad + planeOffset;

  real gradient = 0;
  for (int ph = phstart; ph < phend; ++ph) {
    for (int pw = pwstart; pw < pwend; ++pw) {
      const int p = ph * pooledW + pw;
      if (input == pooledValue[p]) {
        gradient += pooledGrad[p];
      }
    }
  }
  targetGrad[index] =
    scaleB * targetGrad[index] + scaleA * gradient;
}

243
/**
 * Launches KeMaxPoolBackward on STREAM_DEFAULT with one thread per input
 * element (height * width * channels * frameCnt), using 1-D blocks of
 * 1024 threads.
 *
 * All arguments are forwarded to the kernel. When there are no input
 * elements the function returns without launching.
 */
void hl_maxpool_backward(const int frameCnt, const real* inputData,
                        const real* outData, const real* outGrad,
                        const int channels, const int height,
                        const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real scaleA, real scaleB,
                        real* targetGrad, const int outStride) {
  int num_kernels = height * width * channels * frameCnt;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (num_kernels == 0) {
    return;
  }

  int blocks = (num_kernels + 1024 - 1) / 1024;

  KeMaxPoolBackward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, outData, outGrad, channels,
           height, width, pooledH, pooledW, sizeX, sizeY,
           strideH, strideW,
           paddingH, paddingW,
           scaleA, scaleB,
           targetGrad, outStride);
  CHECK_SYNC("hl_maxpool_backward");
}

267 268 269 270 271 272 273
/**
 * Average-pooling forward kernel: one thread per pooled output element.
 *
 * The divisor (pool_size) is the window area clipped to the padded extent
 * (height + padH, width + padW), computed before the window is further
 * clipped to the real input plane. Only elements inside the input plane are
 * summed. The result is written to tgtData at frameNum * tgtStride plus the
 * element's offset within its frame.
 */
__global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
                                 const int channels,
                                 const int height, const int width,
                                 const int pooledH, const int pooledW,
                                 const int sizeX, const int sizeY,
                                 const int strideH, const int strideW,
                                 const int padH, const int padW,
                                 real* tgtData, const int tgtStride) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  const int pw = index % pooledW;
  const int ph = (index / pooledW) % pooledH;
  const int c = (index / pooledW / pooledH) % channels;
  const int frameNum = index / pooledW / pooledH / channels;

  // Window in padded coordinates; its area is the averaging divisor.
  const int hLow = ph * strideH - padH;
  const int wLow = pw * strideW - padW;
  const int hHigh = min(hLow + sizeY, height + padH);
  const int wHigh = min(wLow + sizeX, width + padW);
  const int pool_size = (hHigh - hLow) * (wHigh - wLow);

  // Part of the window that overlaps the real input plane.
  const int hstart = max(hLow, 0);
  const int wstart = max(wLow, 0);
  const int hend = min(hHigh, height);
  const int wend = min(wHigh, width);

  const real* plane = inputData + (frameNum * channels + c) * height * width;
  real aveval = 0;
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      aveval += plane[h * width + w];
    }
  }

  const int frameSize = pooledW * pooledH * channels;
  const int tgtIndex = index % frameSize + frameNum * tgtStride;
  tgtData[tgtIndex] = aveval / pool_size;
}

305 306 307 308 309 310
/**
 * Launches KeAvgPoolForward on STREAM_DEFAULT with one thread per pooled
 * output element (pooledH * pooledW * channels * frameCnt), using 1-D blocks
 * of 1024 threads.
 *
 * All arguments are forwarded to the kernel. When there are no output
 * elements the function returns without launching.
 */
void hl_avgpool_forward(const int frameCnt, const real* inputData,
                        const int channels,
                        const int height, const int width,
                        const int pooledH, const int pooledW,
                        const int sizeX, const int sizeY,
                        const int strideH, const int strideW,
                        const int paddingH, const int paddingW,
                        real* tgtData, const int tgtStride) {
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (num_kernels == 0) {
    return;
  }

  int blocks = (num_kernels + 1024 - 1) / 1024;
  KeAvgPoolForward<<< blocks, 1024, 0, STREAM_DEFAULT >>>
           (num_kernels, inputData, channels,
           height, width, pooledH, pooledW,
           sizeX, sizeY, strideH, strideW,
           paddingH, paddingW, tgtData, tgtStride);
  CHECK_SYNC("hl_avgpool_forward failed");
}

323 324 325 326 327 328 329 330
/**
 * Average-pooling backward kernel: one thread per input element.
 *
 * For the input element at the flat thread index, the kernel visits every
 * pooled output whose window covers it and adds that output's gradient
 * divided by the window's pooling size. The pooling size is the window area
 * clipped to (height + padH, width + padW). The result is blended into
 * tgtGrad as scaleB * tgtGrad[index] + scaleA * gradient.
 */
__global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
                                  const int channels, const int height,
                                  const int width,
                                  const int pooledH, const int pooledW,
                                  const int sizeX, const int sizeY,
                                  const int strideH, const int strideW,
                                  const int padH, const int padW,
                                  real scaleA, real scaleB,
                                  real* tgtGrad, const int outStride) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  // Position of this input element in padded coordinates.
  const int offsetW = index % width + padW;
  const int offsetH = (index / width) % height + padH;
  const int offsetC = (index / width / height) % channels;
  const int frameNum = index / width / height / channels;

  // Range of pooled outputs whose window covers the element.
  int phstart = 0;
  if (offsetH >= sizeY) {
    phstart = (offsetH - sizeY) / strideH + 1;
  }
  int pwstart = 0;
  if (offsetW >= sizeX) {
    pwstart = (offsetW - sizeX) / strideW + 1;
  }
  int phend = 0;
  if (offsetH >= 0) {
    phend = min(offsetH / strideH + 1, pooledH);
  }
  int pwend = 0;
  if (offsetW >= 0) {
    pwend = min(offsetW / strideW + 1, pooledW);
  }

  const real* frameGrad =
    outGrad + (frameNum * outStride + offsetC * pooledH * pooledW);

  real gradient = 0;
  for (int ph = phstart; ph < phend; ++ph) {
    // The vertical extent of the window depends only on ph.
    const int hstart = ph * strideH - padH;
    const int hend = min(hstart + sizeY, height + padH);
    const int windowH = hend - hstart;
    for (int pw = pwstart; pw < pwend; ++pw) {
      const int wstart = pw * strideW - padW;
      const int wend = min(wstart + sizeX, width + padW);
      const int poolsize = windowH * (wend - wstart);
      gradient += frameGrad[ph * pooledW + pw]/poolsize;
    }
  }
  tgtGrad[index] = scaleB * tgtGrad[index] + scaleA * gradient;
}

362 363 364 365 366 367 368 369
/**
 * Launches KeAvgPoolBackward on STREAM_DEFAULT with one thread per input
 * element (height * width * channels * frameCnt), using 1-D blocks of
 * 1024 threads.
 *
 * backGrad is passed to the kernel as tgtGrad; the other arguments are
 * forwarded unchanged. When there are no input elements the function
 * returns without launching.
 */
void hl_avgpool_backward(const int frameCnt, const real* outGrad,
                         const int channels,
                         const int height, const int width,
                         const int pooledH, const int pooledW,
                         const int sizeX, const int sizeY,
                         const int strideH, const int strideW,
                         const int paddingH, const int paddingW,
                         real scaleA, real scaleB,
                         real* backGrad, const int outStride) {
  int num_kernels = height * width * channels * frameCnt;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (num_kernels == 0) {
    return;
  }

  int blocks = (num_kernels + 1024 - 1) / 1024;

  KeAvgPoolBackward <<< blocks, 1024, 0, STREAM_DEFAULT >>>
           (num_kernels, outGrad, channels, height, width,
           pooledH, pooledW, sizeX, sizeY,
           strideH, strideW,
           paddingH, paddingW,
           scaleA, scaleB,
           backGrad, outStride);
  CHECK_SYNC("hl_avgpool_backward failed");
}

L
liaogang 已提交
384
/**
 * Bilinear interpolation forward kernel: one thread per element of the
 * outputH x outputW output matrix.
 *
 * Each output row is treated as numChannels images of outImgSize elements
 * with row width outImgW. The output pixel is mapped to the input by scaling
 * with ratioH / ratioW and truncating to an integer source pixel; the
 * fractional parts become the interpolation weights. hId / wId are 1 when a
 * neighbour exists below / to the right of the source pixel and 0 otherwise,
 * so border pixels reuse their own value.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int nthreads = outputH * outputW;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Location of this element in the output matrix and its channel.
  const int outIdH = tid / outputW;
  const int outIdW = tid % outputW;
  const int inImgSize = inputW / numChannels;
  const int outImgSize = outputW / numChannels;
  const int channelId = outIdW / outImgSize;

  // Vertical source position and weights.
  const int outImgIdy = (outIdW % outImgSize) / outImgW;
  const int inImgIdy = ratioH * outImgIdy;
  const int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
  const real h1lambda = ratioH * outImgIdy - inImgIdy;
  const real h2lambda = 1.f - h1lambda;

  // Horizontal source position and weights.
  const int outImgIdx = tid % outImgW;
  const int inImgIdx = ratioW * outImgIdx;
  const int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
  const real w1lambda = ratioW * outImgIdx - inImgIdx;
  const real w2lambda = 1.f - w1lambda;

  const real* inPos =
    &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];

  // The four neighbouring source pixels.
  const real topLeft = inPos[0];
  const real topRight = inPos[wId];
  const real bottomLeft = inPos[hId * inImgW];
  const real bottomRight = inPos[hId * inImgW + wId];

  // bilinear interpolation
  out[outIdH * outputW + outIdW] =
    h2lambda * (w2lambda * topLeft    + w1lambda * topRight) +
    h1lambda * (w2lambda * bottomLeft + w1lambda * bottomRight);
}

/**
 * Launches KeBilinearInterpFw on STREAM_DEFAULT with one thread per element
 * of the outputH x outputW output matrix, using 1-D blocks of 1024 threads.
 *
 * All arguments are forwarded to the kernel. When the output has no elements
 * the function returns without launching.
 */
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  int threadNum = outputH * outputW;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (threadNum == 0) {
    return;
  }

  int blocks = (threadNum + 1024 - 1) / 1024;

  KeBilinearInterpFw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
    inData, inImgH, inImgW, inputH, inputW, outData, outImgH,
    outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}

L
liaogang 已提交
450
/**
 * Bilinear interpolation backward kernel: one thread per element of the
 * outputH x outputW gradient matrix `out`.
 *
 * Each thread recomputes the source pixel and interpolation weights in the
 * same way as the forward kernel and scatters its output gradient to the
 * four neighbouring input positions. Several output elements can map to the
 * same input position, so the updates go through paddle::paddleAtomicAdd.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int nthreads = outputH * outputW;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Location of this element in the output matrix and its channel.
  const int outIdH = tid / outputW;
  const int outIdW = tid % outputW;
  const int inImgSize = inputW / numChannels;
  const int outImgSize = outputW / numChannels;
  const int channelId = outIdW / outImgSize;

  // Vertical source position and weights.
  const int outImgIdy = (outIdW % outImgSize) / outImgW;
  const int inImgIdy = ratioH * outImgIdy;
  const int hId = (inImgIdy < inImgH - 1) ? 1 : 0;
  const real h1lambda = ratioH * outImgIdy - inImgIdy;
  const real h2lambda = 1.f - h1lambda;

  // Horizontal source position and weights.
  const int outImgIdx = tid % outImgW;
  const int inImgIdx = ratioW * outImgIdx;
  const int wId = (inImgIdx < inImgW - 1) ? 1 : 0;
  const real w1lambda = ratioW * outImgIdx - inImgIdx;
  const real w2lambda = 1.f - w1lambda;

  real* inPos =
    &in[outIdH * inputW + channelId * inImgSize + inImgIdy * inImgW + inImgIdx];
  const real* outPos = &out[outIdH * outputW + outIdW];

  // Distribute the output gradient over the four source pixels.
  paddle::paddleAtomicAdd(&inPos[0],
                          h2lambda * w2lambda * outPos[0]);
  paddle::paddleAtomicAdd(&inPos[wId],
                          h2lambda * w1lambda * outPos[0]);
  paddle::paddleAtomicAdd(&inPos[hId * inImgW],
                          h1lambda * w2lambda * outPos[0]);
  paddle::paddleAtomicAdd(&inPos[hId * inImgW + wId],
                          h1lambda * w1lambda * outPos[0]);
}

/**
 * Launches KeBilinearInterpBw on STREAM_DEFAULT with one thread per element
 * of the outputH x outputW gradient matrix, using 1-D blocks of 1024
 * threads.
 *
 * inGrad is accumulated into by the kernel; all arguments are forwarded
 * unchanged. When the output gradient has no elements the function returns
 * without launching.
 */
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  int threadNum = outputH * outputW;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (threadNum == 0) {
    return;
  }

  int blocks = (threadNum + 1024 - 1) / 1024;

  KeBilinearInterpBw<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
    inGrad, inImgH, inImgW, inputH, inputW, outGrad, outImgH,
    outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}

516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573
/**
 * Maxout forward kernel: one thread per output element.
 *
 * The flat index is split into a batch index and an offset within a sample
 * of `size` outputs; the offset is split again into a channel index and a
 * feature index of length featLen. The thread scans `groups` input values
 * spaced featLen apart, writes the largest to outData[index], and writes
 * the group number of the first occurrence of that maximum to
 * idData[index].
 */
__global__ void maxoutFpCompute(size_t nthreads, const real * inData,
                                real * outData, int* idData,
                                size_t size, size_t featLen, size_t groups) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (!(index < nthreads)) {
    return;
  }

  const size_t batch_idx = index / size;
  const size_t within = index % size;
  const size_t channel_idx = within / featLen;
  const size_t feat_idx = within % featLen;

  // First candidate of this output's group in the input.
  const size_t first =
    (batch_idx * size + channel_idx * featLen) * groups + feat_idx;

  real best = inData[first];
  int bestGroup = 0;
  for (size_t g = 1; g < groups; ++g) {
    const real candidate = inData[first + g * featLen];
    if (candidate > best) {
      best = candidate;
      bestGroup = g;
    }
  }
  outData[index] = best;
  idData[index] = bestGroup;
}

/**
 * Launches maxoutFpCompute on STREAM_DEFAULT with one thread per output
 * element (size * batchSize), using 1-D blocks of 1024 threads.
 *
 * All arguments are forwarded to the kernel. When there are no output
 * elements the function returns without launching.
 */
void hl_maxout_forward(const real* inData, real* outData,
                       int* idData, size_t batchSize, size_t size,
                       size_t featLen, size_t groups) {
  int num_kernels = size * batchSize;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (num_kernels == 0) {
    return;
  }

  int blocks = (num_kernels + 1024 - 1) / 1024;
  maxoutFpCompute<<< blocks, 1024, 0, STREAM_DEFAULT>>>(
    num_kernels, inData, outData, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_forward failed");
}

/**
 * Maxout backward kernel: one thread per output-gradient element.
 *
 * The thread reads the group id stored in idData for its element and adds
 * the element's output gradient to the matching position of inGrad. Input
 * gradients for one sample occupy size * groups consecutive values.
 */
__global__ void maxoutBpCompute(size_t nthreads, real* inGrad,
                                const real* outGrad, const int* idData,
                                size_t size, size_t featLen, size_t groups) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (!(index < nthreads)) {
    return;
  }

  const size_t batch_idx = index / size;
  const size_t within = index % size;
  const size_t channel_idx = within / featLen;
  const size_t feat_idx = within % featLen;

  // Start of this sample in the output-sized buffers.
  const size_t sampleBase = batch_idx * size;
  const int* sampleIds = idData + sampleBase;
  const real* sampleOutGrad = outGrad + sampleBase;
  real* sampleInGrad = inGrad + sampleBase * groups;

  const size_t gradIdx =
    (channel_idx * groups + sampleIds[within]) * featLen + feat_idx;
  sampleInGrad[gradIdx] += sampleOutGrad[within];
}

/**
 * Launches maxoutBpCompute on STREAM_DEFAULT with one thread per
 * output-gradient element (size * batchSize), using 1-D blocks of 1024
 * threads.
 *
 * All arguments are forwarded to the kernel. When there are no elements the
 * function returns without launching.
 */
void hl_maxout_backward(real* inGrad, const real* outGrad,
                        const int* idData, size_t batchSize, size_t size,
                        size_t featLen, size_t groups) {
  int num_kernels = size * batchSize;
  // A grid of 0 blocks is an invalid launch configuration; with nothing to
  // compute, skip the launch.
  if (num_kernels == 0) {
    return;
  }

  int blocks = (num_kernels + 1024 - 1) / 1024;
  maxoutBpCompute<<< blocks, 1024, 0, STREAM_DEFAULT >>>(
    num_kernels, inGrad, outGrad, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_backward failed");
}