hl_cuda_cnn.cu 46.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <float.h>
#include "hl_base.h"
#include "hl_cnn.h"
L
liaogang 已提交
18
#include "hl_device_functions.cuh"
Z
zhangjinchao01 已提交
19

L
liaogang 已提交
20 21 22 23
/**
 * 2-D max pooling, forward pass.
 *
 * Each thread whose flat id is below nthreads produces one pooled value.
 * The id is unravelled as ((frame * channels + c) * pooledH + ph) * pooledW
 * + pw, and the input is addressed as
 * ((frame * channels + c) * height + h) * width + w.
 *
 * @param nthreads   number of pooled elements to compute
 * @param inputData  input values
 * @param channels   channels per frame
 * @param height     input rows
 * @param width      input columns
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param ksizeW     window extent along the width
 * @param ksizeH     window extent along the height
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param offsetH    amount subtracted from the window's starting row
 * @param offsetW    amount subtracted from the window's starting column
 * @param tgtData    receives the maximum of each window
 * @param tgtStride  distance, in elements, between frames in tgtData/maskData
 * @param maskData   if not NULL, receives the in-plane position
 *                   (h * width + w) of the maximum, stored as real
 *
 * A window that does not overlap the input leaves -FLT_MAX in tgtData and
 * -1 in maskData.
 */
__global__ void KeMaxPoolForward(const int nthreads,
                                 const real* inputData,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int pooledH,
                                 const int pooledW,
                                 const int ksizeW,
                                 const int ksizeH,
                                 const int strideH,
                                 const int strideW,
                                 const int offsetH,
                                 const int offsetW,
                                 real* tgtData,
                                 const int tgtStride,
                                 real* maskData) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Unravel the flat id into (frame, channel, pooled row, pooled column).
  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // Pooling window in input coordinates, clipped to the input plane.
  const int rowOrigin = outRow * strideH - offsetH;
  const int colOrigin = outCol * strideW - offsetW;
  const int rowEnd = min(rowOrigin + ksizeH, height);
  const int colEnd = min(colOrigin + ksizeW, width);
  const int rowBegin = max(rowOrigin, 0);
  const int colBegin = max(colOrigin, 0);

  const real* plane = inputData + (frame * channels + chan) * height * width;

  real best = -FLT_MAX;
  int bestPos = -1;
  for (int r = rowBegin; r < rowEnd; ++r) {
    for (int col = colBegin; col < colEnd; ++col) {
      const int pos = r * width + col;
      // Strict comparison: the first maximum encountered wins on ties.
      if (best < plane[pos]) {
        bestPos = pos;
        best = plane[bestPos];
      }
    }
  }

  const int dst = tid % (pooledW * pooledH * channels) + frame * tgtStride;
  tgtData[dst] = best;
  if (maskData != NULL) {
    maskData[dst] = bestPos;
  }
}

X
xzl 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
/**
 * Host launcher for KeMaxPoolForward.
 *
 * Requests pooledH * pooledW * channels * frameCnt threads in blocks of 1024
 * on STREAM_DEFAULT and then runs CHECK_SYNC. All arguments are forwarded to
 * the kernel unchanged; sizeX/sizeY become the kernel's ksizeW/ksizeH and
 * paddingH/paddingW become its offsetH/offsetW.
 *
 * @param frameCnt   number of frames
 * @param inputData  input values
 * @param channels   channels per frame
 * @param height     input rows
 * @param width      input columns
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param sizeX      window extent along the width
 * @param sizeY      window extent along the height
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param paddingH   vertical padding
 * @param paddingW   horizontal padding
 * @param tgtData    pooled output
 * @param tgtStride  distance, in elements, between frames in tgtData/maskData
 * @param maskData   optional buffer for the position of each maximum; the
 *                   kernel skips it when NULL
 */
void hl_maxpool_forward(const int frameCnt,
                        const real* inputData,
                        const int channels,
                        const int height,
                        const int width,
                        const int pooledH,
                        const int pooledW,
                        const int sizeX,
                        const int sizeY,
                        const int strideH,
                        const int strideW,
                        const int paddingH,
                        const int paddingW,
                        real* tgtData,
                        const int tgtStride,
                        real* maskData) {
  const int num_kernels = pooledH * pooledW * channels * frameCnt;

  // A grid with zero blocks is an invalid launch configuration, so an empty
  // workload skips the launch instead of tripping CHECK_SYNC below.
  if (num_kernels != 0) {
    const int blocks = (num_kernels + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KeMaxPoolForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
                                                           inputData,
                                                           channels,
                                                           height,
                                                           width,
                                                           pooledH,
                                                           pooledW,
                                                           sizeX,
                                                           sizeY,
                                                           strideH,
                                                           strideW,
                                                           paddingH,
                                                           paddingW,
                                                           tgtData,
                                                           tgtStride,
                                                           maskData);
  }
  CHECK_SYNC("hl_maxpool_forward failed");
}

L
liaogang 已提交
108 109 110 111 112 113
/**
 * 2-D max pooling, backward pass.
 *
 * Each thread whose flat id is below nthreads owns one input element,
 * addressed as ((frame * channels + c) * height + h) * width + w. It visits
 * every pooling window that covers the element and adds that window's
 * gradient whenever the window's pooled value equals the input value, so
 * every element tied for a window's maximum receives the gradient.
 *
 * @param nthreads    number of input elements
 * @param inputData   forward-pass input values
 * @param outData     forward-pass pooled values
 * @param outGrad     gradient with respect to the pooled values
 * @param channels    channels per frame
 * @param height      input rows
 * @param width       input columns
 * @param pooledH     pooled rows
 * @param pooledW     pooled columns
 * @param sizeX       window extent along the width
 * @param sizeY       window extent along the height
 * @param strideH     vertical step between windows
 * @param strideW     horizontal step between windows
 * @param padH        vertical padding
 * @param padW        horizontal padding
 * @param scaleA      multiplier applied to the newly computed gradient
 * @param scaleB      multiplier applied to the value already in targetGrad
 * @param targetGrad  updated in place: scaleB * old + scaleA * gradient
 * @param outStride   distance, in elements, between frames in outData/outGrad
 */
__global__ void KeMaxPoolBackward(const int nthreads,
                                  const real* inputData,
                                  const real* outData,
                                  const real* outGrad,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int pooledH,
                                  const int pooledW,
                                  const int sizeX,
                                  const int sizeY,
                                  const int strideH,
                                  const int strideW,
                                  const int padH,
                                  const int padW,
                                  real scaleA,
                                  real scaleB,
                                  real* targetGrad,
                                  const int outStride) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= nthreads) {
    return;
  }

  // Location of this input element, shifted into padded coordinates.
  const int paddedCol = gid % width + padW;
  const int paddedRow = (gid / width) % height + padH;
  const int chan = (gid / width / height) % channels;
  const int frame = gid / width / height / channels;

  // Half-open range of pooled rows/columns whose windows cover the element.
  int rowFirst = 0;
  if (paddedRow >= sizeY) {
    rowFirst = (paddedRow - sizeY) / strideH + 1;
  }
  int colFirst = 0;
  if (paddedCol >= sizeX) {
    colFirst = (paddedCol - sizeX) / strideW + 1;
  }
  int rowLast = 0;
  if (paddedRow >= 0) {
    rowLast = min(paddedRow / strideH + 1, pooledH);
  }
  int colLast = 0;
  if (paddedCol >= 0) {
    colLast = min(paddedCol / strideW + 1, pooledW);
  }

  const int planeOffset = frame * outStride + chan * pooledH * pooledW;
  const real* pooledVals = outData + planeOffset;
  const real* pooledGrad = outGrad + planeOffset;
  const real value = inputData[gid];

  real acc = 0;
  for (int pr = rowFirst; pr < rowLast; ++pr) {
    for (int pc = colFirst; pc < colLast; ++pc) {
      const int cell = pr * pooledW + pc;
      if (value == pooledVals[cell]) {
        acc += pooledGrad[cell];
      }
    }
  }
  targetGrad[gid] = scaleB * targetGrad[gid] + scaleA * acc;
}

L
liaogang 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
/**
 * Host launcher for KeMaxPoolBackward.
 *
 * Requests height * width * channels * frameCnt threads (one per input
 * element) in blocks of 1024 on STREAM_DEFAULT and then runs CHECK_SYNC.
 * All arguments are forwarded to the kernel unchanged.
 *
 * @param frameCnt    number of frames
 * @param inputData   forward-pass input values
 * @param outData     forward-pass pooled values
 * @param outGrad     gradient with respect to the pooled values
 * @param channels    channels per frame
 * @param height      input rows
 * @param width       input columns
 * @param pooledH     pooled rows
 * @param pooledW     pooled columns
 * @param sizeX       window extent along the width
 * @param sizeY       window extent along the height
 * @param strideH     vertical step between windows
 * @param strideW     horizontal step between windows
 * @param paddingH    vertical padding
 * @param paddingW    horizontal padding
 * @param scaleA      multiplier applied to the newly computed gradient
 * @param scaleB      multiplier applied to the value already in targetGrad
 * @param targetGrad  input gradient, updated in place
 * @param outStride   distance, in elements, between frames in outData/outGrad
 */
void hl_maxpool_backward(const int frameCnt,
                         const real* inputData,
                         const real* outData,
                         const real* outGrad,
                         const int channels,
                         const int height,
                         const int width,
                         const int pooledH,
                         const int pooledW,
                         const int sizeX,
                         const int sizeY,
                         const int strideH,
                         const int strideW,
                         const int paddingH,
                         const int paddingW,
                         real scaleA,
                         real scaleB,
                         real* targetGrad,
                         const int outStride) {
  const int num_kernels = height * width * channels * frameCnt;

  // A grid with zero blocks is an invalid launch configuration, so an empty
  // workload skips the launch instead of tripping CHECK_SYNC below.
  if (num_kernels != 0) {
    const int blocks = (num_kernels + 1024 - 1) / 1024;

    KeMaxPoolBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
                                                           inputData,
                                                           outData,
                                                           outGrad,
                                                           channels,
                                                           height,
                                                           width,
                                                           pooledH,
                                                           pooledW,
                                                           sizeX,
                                                           sizeY,
                                                           strideH,
                                                           strideW,
                                                           paddingH,
                                                           paddingW,
                                                           scaleA,
                                                           scaleB,
                                                           targetGrad,
                                                           outStride);
  }
  CHECK_SYNC("hl_maxpool_backward");
}

L
liaogang 已提交
199 200
/**
 * 2-D average pooling, forward pass.
 *
 * Each thread whose flat id is below nthreads produces one pooled value.
 * The id is unravelled as ((frame * channels + c) * pooledH + ph) * pooledW
 * + pw, and the input is addressed as
 * ((frame * channels + c) * height + h) * width + w.
 *
 * The divisor is the number of input elements left after the window has been
 * clipped to the input plane, i.e. padded positions are not counted.
 *
 * @param nthreads   number of pooled elements to compute
 * @param inputData  input values
 * @param channels   channels per frame
 * @param height     input rows
 * @param width      input columns
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param sizeX      window extent along the width
 * @param sizeY      window extent along the height
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param padH       vertical padding
 * @param padW       horizontal padding
 * @param tgtData    receives the average of each window
 * @param tgtStride  distance, in elements, between frames in tgtData
 */
__global__ void KeAvgPoolForward(const int nthreads,
                                 const real* inputData,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int pooledH,
                                 const int pooledW,
                                 const int sizeX,
                                 const int sizeY,
                                 const int strideH,
                                 const int strideW,
                                 const int padH,
                                 const int padW,
                                 real* tgtData,
                                 const int tgtStride) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int c = (index / pooledW / pooledH) % channels;
    int frameNum = index / pooledW / pooledH / channels;

    // Pooling window in input coordinates, clipped to the input plane.
    int hstart = ph * strideH - padH;
    int wstart = pw * strideW - padW;
    int hend = min(hstart + sizeY, height);
    int wend = min(wstart + sizeX, width);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // A window that lies entirely in the padding has no elements; dividing
    // by its zero size would store NaN, so such a window averages to 0.
    const bool emptyWindow = (hend <= hstart) || (wend <= wstart);
    int pool_size = (hend - hstart) * (wend - wstart);

    real aveval = 0;
    const real* plane = inputData + (frameNum * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        aveval += plane[h * width + w];
      }
    }
    int tgtIndex =
        index % (pooledW * pooledH * channels) + frameNum * tgtStride;
    tgtData[tgtIndex] = emptyWindow ? static_cast<real>(0) : aveval / pool_size;
  }
}

L
liaogang 已提交
242 243
/**
 * Host launcher for KeAvgPoolForward.
 *
 * Requests pooledH * pooledW * channels * frameCnt threads (one per pooled
 * element) in blocks of 1024 on STREAM_DEFAULT and then runs CHECK_SYNC.
 * All arguments are forwarded to the kernel unchanged.
 *
 * @param frameCnt   number of frames
 * @param inputData  input values
 * @param channels   channels per frame
 * @param height     input rows
 * @param width      input columns
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param sizeX      window extent along the width
 * @param sizeY      window extent along the height
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param paddingH   vertical padding
 * @param paddingW   horizontal padding
 * @param tgtData    pooled output
 * @param tgtStride  distance, in elements, between frames in tgtData
 */
void hl_avgpool_forward(const int frameCnt,
                        const real* inputData,
                        const int channels,
                        const int height,
                        const int width,
                        const int pooledH,
                        const int pooledW,
                        const int sizeX,
                        const int sizeY,
                        const int strideH,
                        const int strideW,
                        const int paddingH,
                        const int paddingW,
                        real* tgtData,
                        const int tgtStride) {
  const int num_kernels = pooledH * pooledW * channels * frameCnt;

  // A grid with zero blocks is an invalid launch configuration, so an empty
  // workload skips the launch instead of tripping CHECK_SYNC below.
  if (num_kernels != 0) {
    const int blocks = (num_kernels + 1024 - 1) / 1024;
    KeAvgPoolForward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
                                                          inputData,
                                                          channels,
                                                          height,
                                                          width,
                                                          pooledH,
                                                          pooledW,
                                                          sizeX,
                                                          sizeY,
                                                          strideH,
                                                          strideW,
                                                          paddingH,
                                                          paddingW,
                                                          tgtData,
                                                          tgtStride);
  }
  CHECK_SYNC("hl_avgpool_forward failed");
}

L
liaogang 已提交
277 278 279 280
/**
 * 2-D average pooling, backward pass.
 *
 * Each thread whose flat id is below nthreads owns one input element,
 * addressed as ((frame * channels + c) * height + h) * width + w. It visits
 * every pooling window that covers the element and accumulates that window's
 * gradient divided by the window's clipped size (the same divisor the
 * forward kernel uses).
 *
 * @param nthreads   number of input elements
 * @param outGrad    gradient with respect to the pooled values
 * @param channels   channels per frame
 * @param height     input rows
 * @param width      input columns
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param sizeX      window extent along the width
 * @param sizeY      window extent along the height
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param padH       vertical padding
 * @param padW       horizontal padding
 * @param scaleA     multiplier applied to the newly computed gradient
 * @param scaleB     multiplier applied to the value already in tgtGrad
 * @param tgtGrad    updated in place: scaleB * old + scaleA * gradient
 * @param outStride  distance, in elements, between frames in outGrad
 */
__global__ void KeAvgPoolBackward(const int nthreads,
                                  const real* outGrad,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int pooledH,
                                  const int pooledW,
                                  const int sizeX,
                                  const int sizeY,
                                  const int strideH,
                                  const int strideW,
                                  const int padH,
                                  const int padW,
                                  real scaleA,
                                  real scaleB,
                                  real* tgtGrad,
                                  const int outStride) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= nthreads) {
    return;
  }

  // Location of this input element, shifted into padded coordinates.
  const int paddedCol = gid % width + padW;
  const int paddedRow = (gid / width) % height + padH;
  const int chan = (gid / width / height) % channels;
  const int frame = gid / width / height / channels;

  // Half-open range of pooled rows/columns whose windows cover the element.
  int rowFirst = 0;
  if (paddedRow >= sizeY) {
    rowFirst = (paddedRow - sizeY) / strideH + 1;
  }
  int colFirst = 0;
  if (paddedCol >= sizeX) {
    colFirst = (paddedCol - sizeX) / strideW + 1;
  }
  int rowLast = 0;
  if (paddedRow >= 0) {
    rowLast = min(paddedRow / strideH + 1, pooledH);
  }
  int colLast = 0;
  if (paddedCol >= 0) {
    colLast = min(paddedCol / strideW + 1, pooledW);
  }

  const real* gradPlane =
      outGrad + (frame * outStride + chan * pooledH * pooledW);

  real acc = 0;
  for (int pr = rowFirst; pr < rowLast; ++pr) {
    // Vertical extent of this window after clipping to the input.
    const int top = pr * strideH - padH;
    const int rows = min(top + sizeY, height) - max(top, 0);
    for (int pc = colFirst; pc < colLast; ++pc) {
      // Horizontal extent of this window after clipping to the input.
      const int left = pc * strideW - padW;
      const int cols = min(left + sizeX, width) - max(left, 0);
      const int area = rows * cols;
      acc += gradPlane[pr * pooledW + pc] / area;
    }
  }
  tgtGrad[gid] = scaleB * tgtGrad[gid] + scaleA * acc;
}

L
liaogang 已提交
325 326
/**
 * Host launcher for KeAvgPoolBackward.
 *
 * Requests height * width * channels * frameCnt threads (one per input
 * element) in blocks of 1024 on STREAM_DEFAULT and then runs CHECK_SYNC.
 * All arguments are forwarded to the kernel unchanged; backGrad becomes the
 * kernel's tgtGrad.
 *
 * @param frameCnt   number of frames
 * @param outGrad    gradient with respect to the pooled values
 * @param channels   channels per frame
 * @param height     input rows
 * @param width      input columns
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param sizeX      window extent along the width
 * @param sizeY      window extent along the height
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param paddingH   vertical padding
 * @param paddingW   horizontal padding
 * @param scaleA     multiplier applied to the newly computed gradient
 * @param scaleB     multiplier applied to the value already in backGrad
 * @param backGrad   input gradient, updated in place
 * @param outStride  distance, in elements, between frames in outGrad
 */
void hl_avgpool_backward(const int frameCnt,
                         const real* outGrad,
                         const int channels,
                         const int height,
                         const int width,
                         const int pooledH,
                         const int pooledW,
                         const int sizeX,
                         const int sizeY,
                         const int strideH,
                         const int strideW,
                         const int paddingH,
                         const int paddingW,
                         real scaleA,
                         real scaleB,
                         real* backGrad,
                         const int outStride) {
  const int num_kernels = height * width * channels * frameCnt;

  // A grid with zero blocks is an invalid launch configuration, so an empty
  // workload skips the launch instead of tripping CHECK_SYNC below.
  if (num_kernels != 0) {
    const int blocks = (num_kernels + 1024 - 1) / 1024;

    KeAvgPoolBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
                                                           outGrad,
                                                           channels,
                                                           height,
                                                           width,
                                                           pooledH,
                                                           pooledW,
                                                           sizeX,
                                                           sizeY,
                                                           strideH,
                                                           strideW,
                                                           paddingH,
                                                           paddingW,
                                                           scaleA,
                                                           scaleB,
                                                           backGrad,
                                                           outStride);
  }
  CHECK_SYNC("hl_avgpool_backward failed");
}

C
chengduoZH 已提交
365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
/**
 * 3-D max pooling, forward pass.
 *
 * Grid-stride loop over nthreads pooled elements. The flat index is
 * unravelled as (((frame * channels + c) * pooledD + pd) * pooledH + ph) *
 * pooledW + pw, and the input is addressed as
 * (((frame * channels + c) * depth + d) * height + h) * width + w.
 *
 * @param nthreads        number of pooled elements to compute
 * @param inputData       input values
 * @param channels        channels per frame
 * @param depth           input depth
 * @param height          input rows
 * @param width           input columns
 * @param pooledD         pooled depth
 * @param pooledH         pooled rows
 * @param pooledW         pooled columns
 * @param ksizeD          window extent along the depth
 * @param ksizeH          window extent along the height
 * @param ksizeW          window extent along the width
 * @param strideD         depth step between windows
 * @param strideH         vertical step between windows
 * @param strideW         horizontal step between windows
 * @param padD            depth padding
 * @param padH            vertical padding
 * @param padW            horizontal padding
 * @param tgtData         receives the maximum of each window
 * @param maxPoolIdxData  receives the in-volume position
 *                        ((d * height + h) * width + w) of the maximum,
 *                        stored as real; written unconditionally
 * @param tgtStride       distance, in elements, between frames in
 *                        tgtData/maxPoolIdxData
 *
 * A window that does not overlap the input leaves -FLT_MAX in tgtData and
 * -1 in maxPoolIdxData.
 */
__global__ void KeMaxPool3DForward(const int nthreads,
                                   const real* inputData,
                                   const int channels,
                                   const int depth,
                                   const int height,
                                   const int width,
                                   const int pooledD,
                                   const int pooledH,
                                   const int pooledW,
                                   const int ksizeD,
                                   const int ksizeH,
                                   const int ksizeW,
                                   const int strideD,
                                   const int strideH,
                                   const int strideW,
                                   const int padD,
                                   const int padH,
                                   const int padW,
                                   real* tgtData,
                                   real* maxPoolIdxData,
                                   const int tgtStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int pd = (index / pooledW / pooledH) % pooledD;
    int c = (index / pooledW / pooledH / pooledD) % channels;
    int frameNum = index / pooledW / pooledH / pooledD / channels;

    // Pooling window in input coordinates, clipped to the input volume.
    int dstart = pd * strideD - padD;
    int hstart = ph * strideH - padH;
    int wstart = pw * strideW - padW;
    int dend = min(dstart + ksizeD, depth);
    int hend = min(hstart + ksizeH, height);
    int wend = min(wstart + ksizeW, width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Use a per-iteration pointer: advancing inputData itself would carry
    // the offset into the next grid-stride iteration and read the wrong
    // volume.
    const real* volume =
        inputData + (frameNum * channels + c) * depth * height * width;

    real maxval = -FLT_MAX;
    int maxIdx = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          const int pos = (d * height + h) * width + w;
          // Strict comparison: the first maximum encountered wins on ties.
          if (maxval < volume[pos]) {
            maxval = volume[pos];
            maxIdx = pos;
          }
        }
      }
    }
    int tgtIndex =
        index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride;
    tgtData[tgtIndex] = maxval;
    maxPoolIdxData[tgtIndex] = maxIdx;
  }
}

/**
 * Host launcher for KeMaxPool3DForward.
 *
 * Requests pooledD * pooledH * pooledW * channels * frameCnt threads (one
 * per pooled element) in blocks of 1024 on STREAM_DEFAULT and then runs
 * CHECK_SYNC. All arguments are forwarded to the kernel unchanged;
 * sizeZ/sizeY/sizeX become the kernel's ksizeD/ksizeH/ksizeW.
 *
 * @param frameCnt        number of frames
 * @param inputData       input values
 * @param channels        channels per frame
 * @param depth           input depth
 * @param height          input rows
 * @param width           input columns
 * @param pooledD         pooled depth
 * @param pooledH         pooled rows
 * @param pooledW         pooled columns
 * @param sizeZ           window extent along the depth
 * @param sizeY           window extent along the height
 * @param sizeX           window extent along the width
 * @param strideD         depth step between windows
 * @param strideH         vertical step between windows
 * @param strideW         horizontal step between windows
 * @param padD            depth padding
 * @param padH            vertical padding
 * @param padW            horizontal padding
 * @param tgtData         pooled output
 * @param maxPoolIdxData  receives the position of each maximum
 * @param tgtStride       distance, in elements, between frames in
 *                        tgtData/maxPoolIdxData
 */
void hl_maxpool3D_forward(const int frameCnt,
                          const real* inputData,
                          const int channels,
                          const int depth,
                          const int height,
                          const int width,
                          const int pooledD,
                          const int pooledH,
                          const int pooledW,
                          const int sizeZ,
                          const int sizeY,
                          const int sizeX,
                          const int strideD,
                          const int strideH,
                          const int strideW,
                          const int padD,
                          const int padH,
                          const int padW,
                          real* tgtData,
                          real* maxPoolIdxData,
                          const int tgtStride) {
  const int num_kernels = pooledD * pooledH * pooledW * channels * frameCnt;

  // A grid with zero blocks is an invalid launch configuration, so an empty
  // workload skips the launch instead of tripping CHECK_SYNC below.
  if (num_kernels != 0) {
    const int blocks = (num_kernels + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KeMaxPool3DForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
                                                             inputData,
                                                             channels,
                                                             depth,
                                                             height,
                                                             width,
                                                             pooledD,
                                                             pooledH,
                                                             pooledW,
                                                             sizeZ,
                                                             sizeY,
                                                             sizeX,
                                                             strideD,
                                                             strideH,
                                                             strideW,
                                                             padD,
                                                             padH,
                                                             padW,
                                                             tgtData,
                                                             maxPoolIdxData,
                                                             tgtStride);
  }
  CHECK_SYNC("hl_maxpool3D_forward failed");
}

/**
 * 3-D max pooling, backward pass.
 *
 * Grid-stride loop over nthreads input elements, addressed as
 * (((frame * channels + c) * depth + d) * height + h) * width + w. For each
 * element it visits every pooling window that covers it and adds that
 * window's gradient when the recorded maximum position equals the element's
 * in-volume position ((d * height + h) * width + w).
 *
 * @param nthreads        number of input elements
 * @param outGrad         gradient with respect to the pooled values
 * @param channels        channels per frame
 * @param depth           input depth
 * @param height          input rows
 * @param width           input columns
 * @param pooledD         pooled depth
 * @param pooledH         pooled rows
 * @param pooledW         pooled columns
 * @param sizeZ           window extent along the depth
 * @param sizeY           window extent along the height
 * @param sizeX           window extent along the width
 * @param strideD         depth step between windows
 * @param strideH         vertical step between windows
 * @param strideW         horizontal step between windows
 * @param padD            depth padding
 * @param padH            vertical padding
 * @param padW            horizontal padding
 * @param scaleA          multiplier applied to the newly computed gradient
 * @param scaleB          multiplier applied to the value already in targetGrad
 * @param targetGrad      updated in place: scaleA * gradient + scaleB * old
 * @param maxPoolIdxData  maximum positions recorded by the forward pass
 * @param outStride       not read by this kernel
 *
 * NOTE(review): outGrad and maxPoolIdxData are addressed with a dense
 * per-frame stride (channels * pooledD * pooledH * pooledW), whereas
 * KeMaxPool3DForward writes with frameNum * tgtStride. The two agree only
 * when the stride equals the dense frame size -- confirm against callers.
 */
__global__ void KeMaxPool3DBackward(const int nthreads,
                                    const real* outGrad,
                                    const int channels,
                                    const int depth,
                                    const int height,
                                    const int width,
                                    const int pooledD,
                                    const int pooledH,
                                    const int pooledW,
                                    const int sizeZ,
                                    const int sizeY,
                                    const int sizeX,
                                    const int strideD,
                                    const int strideH,
                                    const int strideW,
                                    const int padD,
                                    const int padH,
                                    const int padW,
                                    real scaleA,
                                    real scaleB,
                                    real* targetGrad,
                                    real* maxPoolIdxData,
                                    const int outStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    int offsetW = index % width;
    int offsetH = (index / width) % height;
    int offsetD = (index / width / height) % depth;
    int offsetC = (index / width / height / depth) % channels;
    int frameNum = index / width / height / depth / channels;

    // Half-open range of pooled cells whose windows cover this element.
    int pdstart =
        (offsetD + padD < sizeZ) ? 0 : (offsetD + padD - sizeZ) / strideD + 1;
    int phstart =
        (offsetH + padH < sizeY) ? 0 : (offsetH + padH - sizeY) / strideH + 1;
    int pwstart =
        (offsetW + padW < sizeX) ? 0 : (offsetW + padW - sizeX) / strideW + 1;
    int pdend = min((offsetD + padD) / strideD + 1, pooledD);
    int phend = min((offsetH + padH) / strideH + 1, pooledH);
    int pwend = min((offsetW + padW) / strideW + 1, pooledW);

    // Use per-iteration pointers: advancing outGrad/maxPoolIdxData themselves
    // would carry the offset into the next grid-stride iteration.
    const int planeOffset =
        (frameNum * channels + offsetC) * pooledD * pooledH * pooledW;
    const real* gradVolume = outGrad + planeOffset;
    const real* idxVolume = maxPoolIdxData + planeOffset;

    const int selfPos = (offsetD * height + offsetH) * width + offsetW;
    real gradient = 0;
    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int cell = (pd * pooledH + ph) * pooledW + pw;
          if (selfPos == idxVolume[cell]) {
            gradient += gradVolume[cell];
          }
        }
      }
    }
    targetGrad[index] = scaleA * gradient + scaleB * targetGrad[index];
  }
}

/**
 * Host launcher for KeMaxPool3DBackward.
 *
 * Requests depth * height * width * channels * frameCnt threads (one per
 * input element) in blocks of 1024 on STREAM_DEFAULT and then runs
 * CHECK_SYNC. All arguments are forwarded to the kernel unchanged;
 * outputD/outputH/outputW become the kernel's pooledD/pooledH/pooledW.
 *
 * @param frameCnt        number of frames
 * @param outGrad         gradient with respect to the pooled values
 * @param channels        channels per frame
 * @param depth           input depth
 * @param height          input rows
 * @param width           input columns
 * @param outputD         pooled depth
 * @param outputH         pooled rows
 * @param outputW         pooled columns
 * @param sizeZ           window extent along the depth
 * @param sizeY           window extent along the height
 * @param sizeX           window extent along the width
 * @param strideD         depth step between windows
 * @param strideH         vertical step between windows
 * @param strideW         horizontal step between windows
 * @param paddingD        depth padding
 * @param paddingH        vertical padding
 * @param paddingW        horizontal padding
 * @param scaleA          multiplier applied to the newly computed gradient
 * @param scaleB          multiplier applied to the value already in targetGrad
 * @param targetGrad      input gradient, updated in place
 * @param maxPoolIdxData  maximum positions recorded by the forward pass
 * @param outStride       forwarded to the kernel
 */
void hl_maxpool3D_backward(const int frameCnt,
                           const real* outGrad,
                           const int channels,
                           const int depth,
                           const int height,
                           const int width,
                           const int outputD,
                           const int outputH,
                           const int outputW,
                           const int sizeZ,
                           const int sizeY,
                           const int sizeX,
                           const int strideD,
                           const int strideH,
                           const int strideW,
                           const int paddingD,
                           const int paddingH,
                           const int paddingW,
                           real scaleA,
                           real scaleB,
                           real* targetGrad,
                           real* maxPoolIdxData,
                           const int outStride) {
  const int num_kernels = depth * height * width * channels * frameCnt;

  // A grid with zero blocks is an invalid launch configuration, so an empty
  // workload skips the launch instead of tripping CHECK_SYNC below.
  if (num_kernels != 0) {
    const int blocks = (num_kernels + 1024 - 1) / 1024;

    KeMaxPool3DBackward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
                                                             outGrad,
                                                             channels,
                                                             depth,
                                                             height,
                                                             width,
                                                             outputD,
                                                             outputH,
                                                             outputW,
                                                             sizeZ,
                                                             sizeY,
                                                             sizeX,
                                                             strideD,
                                                             strideH,
                                                             strideW,
                                                             paddingD,
                                                             paddingH,
                                                             paddingW,
                                                             scaleA,
                                                             scaleB,
                                                             targetGrad,
                                                             maxPoolIdxData,
                                                             outStride);
  }
  CHECK_SYNC("hl_maxpool3D_backward");
}

/**
 * 3-D average pooling, forward pass.
 *
 * Grid-stride loop over nthreads pooled elements. The flat index is
 * unravelled as (((frame * channels + c) * pooledD + pd) * pooledH + ph) *
 * pooledW + pw, and the input is addressed as
 * (((frame * channels + c) * depth + d) * height + h) * width + w.
 *
 * The divisor is the number of input elements left after the window has been
 * clipped to the input volume, i.e. padded positions are not counted.
 *
 * @param nthreads   number of pooled elements to compute
 * @param inputData  input values
 * @param channels   channels per frame
 * @param depth      input depth
 * @param height     input rows
 * @param width      input columns
 * @param pooledD    pooled depth
 * @param pooledH    pooled rows
 * @param pooledW    pooled columns
 * @param sizeZ      window extent along the depth
 * @param sizeY      window extent along the height
 * @param sizeX      window extent along the width
 * @param strideD    depth step between windows
 * @param strideH    vertical step between windows
 * @param strideW    horizontal step between windows
 * @param padD       depth padding
 * @param padH       vertical padding
 * @param padW       horizontal padding
 * @param tgtData    receives the average of each window
 * @param tgtStride  distance, in elements, between frames in tgtData
 */
__global__ void KeAvgPool3DForward(const int nthreads,
                                   const real* inputData,
                                   const int channels,
                                   const int depth,
                                   const int height,
                                   const int width,
                                   const int pooledD,
                                   const int pooledH,
                                   const int pooledW,
                                   const int sizeZ,
                                   const int sizeY,
                                   const int sizeX,
                                   const int strideD,
                                   const int strideH,
                                   const int strideW,
                                   const int padD,
                                   const int padH,
                                   const int padW,
                                   real* tgtData,
                                   const int tgtStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int pd = (index / pooledW / pooledH) % pooledD;
    int c = (index / pooledW / pooledH / pooledD) % channels;
    int frameNum = index / pooledW / pooledH / pooledD / channels;

    // Pooling window in input coordinates, clipped to the input volume.
    int dstart = pd * strideD - padD;
    int hstart = ph * strideH - padH;
    int wstart = pw * strideW - padW;
    int dend = min(dstart + sizeZ, depth);
    int hend = min(hstart + sizeY, height);
    int wend = min(wstart + sizeX, width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // A window that lies entirely in the padding has no elements; dividing
    // by its zero size would store NaN, so such a window averages to 0.
    const bool emptyWindow =
        (dend <= dstart) || (hend <= hstart) || (wend <= wstart);
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);

    // Use a per-iteration pointer: advancing inputData itself would carry
    // the offset into the next grid-stride iteration and read the wrong
    // volume.
    const real* volume =
        inputData + (frameNum * channels + c) * depth * height * width;

    real aveval = 0;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          aveval += volume[(d * height + h) * width + w];
        }
      }
    }
    int tgtIndex =
        index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride;
    tgtData[tgtIndex] = emptyWindow ? static_cast<real>(0) : aveval / pool_size;
  }
}

/**
 * @brief Host-side launcher for the 3D average-pooling forward kernel.
 *
 * Launches KeAvgPool3DForward on STREAM_DEFAULT with one logical work item
 * per pooled output element (pooledD * pooledH * pooledW * channels *
 * frameCnt) using 1024 threads per block, then invokes CHECK_SYNC.
 *
 * @param frameCnt            number of frames (samples) to process.
 * @param inputData           input values, forwarded to the kernel.
 * @param channels            number of channels per frame.
 * @param depth,height,width  input volume extents.
 * @param pooledD,pooledH,pooledW  pooled (output) volume extents.
 * @param sizeZ,sizeY,sizeX   pooling window extents along D/H/W.
 * @param strideD,strideH,strideW  pooling strides along D/H/W.
 * @param paddingD,paddingH,paddingW  padding along D/H/W.
 * @param tgtData             output buffer written by the kernel.
 * @param tgtStride           per-frame stride (in elements) of tgtData.
 */
void hl_avgpool3D_forward(const int frameCnt,
                          const real* inputData,
                          const int channels,
                          const int depth,
                          const int height,
                          const int width,
                          const int pooledD,
                          const int pooledH,
                          const int pooledW,
                          const int sizeZ,
                          const int sizeY,
                          const int sizeX,
                          const int strideD,
                          const int strideH,
                          const int strideW,
                          const int paddingD,
                          const int paddingH,
                          const int paddingW,
                          real* tgtData,
                          const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  // Total number of pooled output elements across every frame.
  const int totalOutputs = pooledD * pooledH * pooledW * channels * frameCnt;
  // Ceil-divide so the tail elements still get a thread.
  const int gridSize = (totalOutputs + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPool3DForward<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      totalOutputs,
      inputData,
      channels,
      depth,
      height,
      width,
      pooledD,
      pooledH,
      pooledW,
      sizeZ,
      sizeY,
      sizeX,
      strideD,
      strideH,
      strideW,
      paddingD,
      paddingH,
      paddingW,
      tgtData,
      tgtStride);
  // NOTE(review): CHECK_SYNC is a project macro; presumably it reports
  // launch/execution failures with the given message.
  CHECK_SYNC("hl_avgpool3D_forward failed");
}

/**
 * @brief Backward pass of 3D average pooling.
 *
 * Each logical work item handles one element of the input-gradient volume
 * (index enumerates frame, channel, depth, height, width with width varying
 * fastest). It sums the contributions of every pooling window that covers the
 * element, each divided by that window's size after clipping to the input
 * bounds, and blends the result into tgtGrad:
 *
 *   tgtGrad[index] = scaleA * gradient + scaleB * tgtGrad[index]
 *
 * The kernel iterates with a grid-stride loop, so any 1D grid size is valid.
 *
 * @param nthreads   total number of input-gradient elements.
 * @param outGrad    gradient w.r.t. the pooled output; frame f starts at
 *                   outGrad + f * outStride, channels inside a frame are
 *                   packed as pooledD * pooledH * pooledW elements each.
 * @param channels   number of channels per frame.
 * @param depth,height,width       input volume extents.
 * @param pooledD,pooledH,pooledW  pooled volume extents.
 * @param sizeZ,sizeY,sizeX        pooling window extents along D/H/W.
 * @param strideD,strideH,strideW  pooling strides along D/H/W.
 * @param padD,padH,padW           padding along D/H/W.
 * @param scaleA     multiplier applied to the freshly computed gradient.
 * @param scaleB     multiplier applied to the existing value in tgtGrad.
 * @param tgtGrad    input-gradient buffer, read and written at [index].
 * @param outStride  per-frame stride (in elements) of outGrad.
 */
__global__ void KeAvgPool3DBackward(const int nthreads,
                                    const real* outGrad,
                                    const int channels,
                                    const int depth,
                                    const int height,
                                    const int width,
                                    const int pooledD,
                                    const int pooledH,
                                    const int pooledW,
                                    const int sizeZ,
                                    const int sizeY,
                                    const int sizeX,
                                    const int strideD,
                                    const int strideH,
                                    const int strideW,
                                    const int padD,
                                    const int padH,
                                    const int padW,
                                    real scaleA,
                                    real scaleB,
                                    real* tgtGrad,
                                    const int outStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    // Coordinates of this element in the padded input volume.
    int offsetW = index % width + padW;
    int offsetH = (index / width) % height + padH;
    int offsetD = (index / width / height) % depth + padD;
    int offsetC = (index / width / height / depth) % channels;
    int frameNum = index / width / height / depth / channels;

    // Half-open range [start, end) of pooling windows covering the element
    // along each axis.
    int pdstart = (offsetD < sizeZ) ? 0 : (offsetD - sizeZ) / strideD + 1;
    int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1;
    int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1;
    int pdend = min(offsetD / strideD + 1, pooledD);
    int phend = min(offsetH / strideH + 1, pooledH);
    int pwend = min(offsetW / strideW + 1, pooledW);

    real gradient = 0;
    // Use a per-iteration local pointer: mutating the outGrad parameter here
    // would accumulate offsets across grid-stride iterations. The frame
    // offset uses outStride so it matches the strided layout written by the
    // forward kernel.
    const real* outGradSlice =
        outGrad + frameNum * outStride + offsetC * pooledD * pooledH * pooledW;

    for (int pd = pdstart; pd < pdend; ++pd) {
      int dstart = pd * strideD - padD;
      int dend = min(dstart + sizeZ, depth);
      dstart = max(dstart, 0);
      for (int ph = phstart; ph < phend; ++ph) {
        int hstart = ph * strideH - padH;
        int hend = min(hstart + sizeY, height);
        hstart = max(hstart, 0);
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size
          int wstart = pw * strideW - padW;
          int wend = min(wstart + sizeX, width);
          wstart = max(wstart, 0);
          // Window size after clipping to the input bounds (padding is not
          // counted), same formula as the forward kernel's divisor.
          int poolsize = (dend - dstart) * (hend - hstart) * (wend - wstart);
          gradient +=
              outGradSlice[(pd * pooledH + ph) * pooledW + pw] / poolsize;
        }
      }
    }
    tgtGrad[index] = scaleA * gradient + scaleB * tgtGrad[index];
  }
}

/**
 * @brief Host-side launcher for the 3D average-pooling backward kernel.
 *
 * Launches KeAvgPool3DBackward on STREAM_DEFAULT with one logical work item
 * per input-gradient element (depth * height * width * channels * frameCnt)
 * using 1024 threads per block, then invokes CHECK_SYNC.
 *
 * @param frameCnt            number of frames (samples).
 * @param outGrad             gradient w.r.t. the pooled output.
 * @param channels            number of channels per frame.
 * @param depth,height,width  input volume extents.
 * @param outputD,outputH,outputW  pooled volume extents.
 * @param sizeZ,sizeY,sizeX   pooling window extents along D/H/W.
 * @param strideD,strideH,strideW  pooling strides along D/H/W.
 * @param paddingD,paddingH,paddingW  padding along D/H/W.
 * @param scaleA              multiplier for the newly computed gradient.
 * @param scaleB              multiplier for the existing backGrad value.
 * @param backGrad            input-gradient buffer updated by the kernel.
 * @param outStride           forwarded to the kernel as its outStride.
 */
void hl_avgpool3D_backward(const int frameCnt,
                           const real* outGrad,
                           const int channels,
                           const int depth,
                           const int height,
                           const int width,
                           const int outputD,
                           const int outputH,
                           const int outputW,
                           const int sizeZ,
                           const int sizeY,
                           const int sizeX,
                           const int strideD,
                           const int strideH,
                           const int strideW,
                           int paddingD,
                           int paddingH,
                           int paddingW,
                           real scaleA,
                           real scaleB,
                           real* backGrad,
                           const int outStride) {
  const int kThreadsPerBlock = 1024;
  // Total number of input-gradient elements across every frame.
  const int totalInputs = depth * height * width * channels * frameCnt;
  // Ceil-divide so the tail elements still get a thread.
  const int gridSize = (totalInputs + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPool3DBackward<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      totalInputs,
      outGrad,
      channels,
      depth,
      height,
      width,
      outputD,
      outputH,
      outputW,
      sizeZ,
      sizeY,
      sizeX,
      strideD,
      strideH,
      strideW,
      paddingD,
      paddingH,
      paddingW,
      scaleA,
      scaleB,
      backGrad,
      outStride);
  CHECK_SYNC("hl_avgpool3D_backward failed");
}

L
liaogang 已提交
792
/**
 * @brief Forward bilinear interpolation kernel.
 *
 * One thread produces one element of the outputH x outputW output matrix.
 * The column index is split into a channel and a pixel position inside an
 * outImgW-wide image; the matching source position is found by scaling with
 * ratioH / ratioW and truncating to int, and the output is the weighted sum
 * of the (up to) four neighbouring source values.
 *
 * @param in           input matrix with inputW elements per row.
 * @param inImgH,inImgW    source image height / width.
 * @param inputH       not read by this kernel.
 * @param inputW       input row length (numChannels source images).
 * @param out          output matrix with outputW elements per row.
 * @param outImgH      not read by this kernel.
 * @param outImgW      destination image width.
 * @param outputH,outputW  output matrix dimensions.
 * @param numChannels  number of images per row.
 * @param ratioH,ratioW    scale factors mapping destination to source
 *                     coordinates.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int totalElems = outputH * outputW;
  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the end of the output have nothing to do.
  if (gid >= totalElems) {
    return;
  }

  // Row (sample) and column of the output element handled by this thread.
  int row = gid / outputW;
  int col = gid % outputW;
  // Elements per single-channel image in the input / output rows.
  int srcPlane = inputW / numChannels;
  int dstPlane = outputW / numChannels;
  int chan = col / dstPlane;

  // Vertical source coordinate and interpolation weights.
  int dstY = (col % dstPlane) / outImgW;
  int srcY = ratioH * dstY;
  // 1 when a row below exists, 0 on the last row (neighbour collapses).
  int stepY = (srcY < inImgH - 1) ? 1 : 0;
  real fracY = ratioH * dstY - srcY;
  real keepY = 1.f - fracY;

  // Horizontal source coordinate and interpolation weights.
  int dstX = gid % outImgW;
  int srcX = ratioW * dstX;
  // 1 when a column to the right exists, 0 on the last column.
  int stepX = (srcX < inImgW - 1) ? 1 : 0;
  real fracX = ratioW * dstX - srcX;
  real keepX = 1.f - fracX;

  // Top-left source sample of the 2x2 neighbourhood.
  const real* src =
      &in[row * inputW + chan * srcPlane + srcY * inImgW + srcX];

  // Interpolate along X on the upper and lower rows, then blend along Y.
  real upper = keepX * src[0] + fracX * src[stepX];
  real lower =
      keepX * src[stepY * inImgW] + fracX * src[stepY * inImgW + stepX];
  out[row * outputW + col] = keepY * upper + fracY * lower;
}

/**
 * @brief Host-side launcher for the bilinear interpolation forward kernel.
 *
 * Launches KeBilinearInterpFw on STREAM_DEFAULT with one thread per output
 * element (outputH * outputW), 1024 threads per block, then invokes
 * CHECK_SYNC. All arguments are forwarded unchanged to the kernel.
 *
 * @param inData       input matrix.
 * @param inImgH,inImgW    source image height / width.
 * @param inputH,inputW    input matrix dimensions.
 * @param outData      output matrix.
 * @param outImgH,outImgW  destination image height / width.
 * @param outputH,outputW  output matrix dimensions.
 * @param numChannels  number of channels.
 * @param ratioH,ratioW    destination-to-source scale factors.
 */
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int kThreadsPerBlock = 1024;
  int totalElems = outputH * outputW;
  // Ceil-divide so a partially filled last block covers the tail.
  int gridSize = (totalElems + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeBilinearInterpFw<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      inData,
      inImgH,
      inImgW,
      inputH,
      inputW,
      outData,
      outImgH,
      outImgW,
      outputH,
      outputW,
      numChannels,
      ratioH,
      ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}

L
liaogang 已提交
869
/**
 * @brief Backward bilinear interpolation kernel.
 *
 * One thread handles one element of the outputH x outputW output-gradient
 * matrix. It recomputes the same source position and weights as the forward
 * kernel and scatters the output gradient onto the (up to) four neighbouring
 * input-gradient entries with paddle::paddleAtomicAdd, since several output
 * elements can map to the same input entry.
 *
 * @param in           input-gradient matrix, accumulated into.
 * @param inImgH,inImgW    source image height / width.
 * @param inputH       not read by this kernel.
 * @param inputW       input row length (numChannels source images).
 * @param out          output-gradient matrix, read only.
 * @param outImgH      not read by this kernel.
 * @param outImgW      destination image width.
 * @param outputH,outputW  output matrix dimensions.
 * @param numChannels  number of images per row.
 * @param ratioH,ratioW    scale factors mapping destination to source
 *                     coordinates.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int totalElems = outputH * outputW;
  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the end of the output gradient have nothing to do.
  if (gid >= totalElems) {
    return;
  }

  // Row (sample) and column of the output-gradient element.
  int row = gid / outputW;
  int col = gid % outputW;
  // Elements per single-channel image in the input / output rows.
  int srcPlane = inputW / numChannels;
  int dstPlane = outputW / numChannels;
  int chan = col / dstPlane;

  // Vertical source coordinate and interpolation weights.
  int dstY = (col % dstPlane) / outImgW;
  int srcY = ratioH * dstY;
  // 1 when a row below exists, 0 on the last row.
  int stepY = (srcY < inImgH - 1) ? 1 : 0;
  real fracY = ratioH * dstY - srcY;
  real keepY = 1.f - fracY;

  // Horizontal source coordinate and interpolation weights.
  int dstX = gid % outImgW;
  int srcX = ratioW * dstX;
  // 1 when a column to the right exists, 0 on the last column.
  int stepX = (srcX < inImgW - 1) ? 1 : 0;
  real fracX = ratioW * dstX - srcX;
  real keepX = 1.f - fracX;

  // Top-left entry of the 2x2 neighbourhood in the input gradient.
  real* dst = &in[row * inputW + chan * srcPlane + srcY * inImgW + srcX];
  const real* grad = &out[row * outputW + col];

  // Distribute the gradient with the same weights the forward pass used.
  paddle::paddleAtomicAdd(&dst[0], keepY * keepX * grad[0]);
  paddle::paddleAtomicAdd(&dst[stepX], keepY * fracX * grad[0]);
  paddle::paddleAtomicAdd(&dst[stepY * inImgW], fracY * keepX * grad[0]);
  paddle::paddleAtomicAdd(&dst[stepY * inImgW + stepX],
                          fracY * fracX * grad[0]);
}

/**
 * @brief Host-side launcher for the bilinear interpolation backward kernel.
 *
 * Launches KeBilinearInterpBw on STREAM_DEFAULT with one thread per
 * output-gradient element (outputH * outputW), 1024 threads per block, then
 * invokes CHECK_SYNC. All arguments are forwarded unchanged to the kernel.
 *
 * @param inGrad       input-gradient matrix, accumulated into.
 * @param inImgH,inImgW    source image height / width.
 * @param inputH,inputW    input matrix dimensions.
 * @param outGrad      output-gradient matrix.
 * @param outImgH,outImgW  destination image height / width.
 * @param outputH,outputW  output matrix dimensions.
 * @param numChannels  number of channels.
 * @param ratioH,ratioW    destination-to-source scale factors.
 */
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int kThreadsPerBlock = 1024;
  int totalElems = outputH * outputW;
  // Ceil-divide so a partially filled last block covers the tail.
  int gridSize = (totalElems + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeBilinearInterpBw<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      inGrad,
      inImgH,
      inImgW,
      inputH,
      inputW,
      outGrad,
      outImgH,
      outImgW,
      outputH,
      outputW,
      numChannels,
      ratioH,
      ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}

L
liaogang 已提交
947 948 949 950 951 952 953
/**
 * @brief Maxout forward kernel.
 *
 * One thread produces one output element. For output position
 * (batch, channel, feature) it scans `groups` input values spaced featLen
 * apart, writes the largest to outData and the index of the group that held
 * it to idData. Ties keep the earliest group (strict `>` comparison).
 *
 * @param nthreads  number of output elements.
 * @param inData    input values; each sample holds size * groups elements.
 * @param outData   output maxima, one per output element.
 * @param idData    winning group index, one per output element.
 * @param size      number of output elements per sample.
 * @param featLen   number of features per channel.
 * @param groups    number of input candidates per output element.
 */
__global__ void maxoutFpCompute(size_t nthreads,
                                const real* inData,
                                real* outData,
                                int* idData,
                                size_t size,
                                size_t featLen,
                                size_t groups) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the last output element have nothing to do.
  if (index >= nthreads) {
    return;
  }

  // Decompose the flat output index into sample / channel / feature.
  size_t sample = index / size;
  size_t inSample = index % size;
  size_t chan = inSample / featLen;
  size_t feat = inSample % featLen;

  // Position of the first (group 0) candidate in the input.
  size_t cursor = (sample * size + chan * featLen) * groups + feat;

  real best = inData[cursor];
  int bestGroup = 0;
  for (size_t g = 1; g < groups; ++g) {
    // Successive candidates are featLen elements apart.
    cursor += featLen;
    real candidate = inData[cursor];
    if (candidate > best) {
      best = candidate;
      bestGroup = g;
    }
  }

  outData[index] = best;
  idData[index] = bestGroup;
}

L
liaogang 已提交
976 977 978 979 980 981 982
/**
 * @brief Host-side launcher for the maxout forward kernel.
 *
 * Launches maxoutFpCompute on STREAM_DEFAULT with one thread per output
 * element (size * batchSize), 1024 threads per block, then invokes
 * CHECK_SYNC.
 *
 * @param inData     input values.
 * @param outData    output maxima.
 * @param idData     output winning-group indices.
 * @param batchSize  number of samples.
 * @param size       number of output elements per sample.
 * @param featLen    number of features per channel.
 * @param groups     number of input candidates per output element.
 */
void hl_maxout_forward(const real* inData,
                       real* outData,
                       int* idData,
                       size_t batchSize,
                       size_t size,
                       size_t featLen,
                       size_t groups) {
  const int kThreadsPerBlock = 1024;
  int totalOutputs = size * batchSize;
  // Ceil-divide so a partially filled last block covers the tail.
  int gridSize = (totalOutputs + kThreadsPerBlock - 1) / kThreadsPerBlock;

  maxoutFpCompute<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      totalOutputs, inData, outData, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_forward failed");
}

L
liaogang 已提交
990 991 992 993 994 995 996
/**
 * @brief Maxout backward kernel.
 *
 * One thread handles one output-gradient element and adds it to the
 * input-gradient entry of the group recorded in idData for that element.
 * The update is a plain (non-atomic) `+=`.
 *
 * @param nthreads  number of output-gradient elements.
 * @param inGrad    input gradient, accumulated into; each sample holds
 *                  size * groups elements.
 * @param outGrad   output gradient, one value per output element.
 * @param idData    winning group index per output element.
 * @param size      number of output elements per sample.
 * @param featLen   number of features per channel.
 * @param groups    number of input candidates per output element.
 */
__global__ void maxoutBpCompute(size_t nthreads,
                                real* inGrad,
                                const real* outGrad,
                                const int* idData,
                                size_t size,
                                size_t featLen,
                                size_t groups) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the last output element have nothing to do.
  if (index >= nthreads) {
    return;
  }

  // Decompose the flat output index into sample / channel / feature.
  size_t sample = index / size;
  size_t inSample = index % size;
  size_t chan = inSample / featLen;
  size_t feat = inSample % featLen;
  // Offset of this sample's first element in the output-sized buffers.
  size_t sampleBase = sample * size;

  // Offset, inside the sample's input gradient, of the winning candidate.
  size_t winnerOffset =
      (chan * groups + idData[sampleBase + inSample]) * featLen + feat;
  // The input gradient holds `groups` times as many elements per sample.
  inGrad[sampleBase * groups + winnerOffset] += outGrad[sampleBase + inSample];
}

L
liaogang 已提交
1010 1011 1012 1013 1014 1015 1016
/**
 * @brief Host-side launcher for the maxout backward kernel.
 *
 * Launches maxoutBpCompute on STREAM_DEFAULT with one thread per
 * output-gradient element (size * batchSize), 1024 threads per block, then
 * invokes CHECK_SYNC.
 *
 * @param inGrad     input gradient, accumulated into.
 * @param outGrad    output gradient.
 * @param idData     winning group index per output element.
 * @param batchSize  number of samples.
 * @param size       number of output elements per sample.
 * @param featLen    number of features per channel.
 * @param groups     number of input candidates per output element.
 */
void hl_maxout_backward(real* inGrad,
                        const real* outGrad,
                        const int* idData,
                        size_t batchSize,
                        size_t size,
                        size_t featLen,
                        size_t groups) {
  const int kThreadsPerBlock = 1024;
  int totalOutputs = size * batchSize;
  // Ceil-divide so a partially filled last block covers the tail.
  int gridSize = (totalOutputs + kThreadsPerBlock - 1) / kThreadsPerBlock;

  maxoutBpCompute<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      totalOutputs, inGrad, outGrad, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_backward failed");
}