hl_cuda_cnn.cu 45.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <float.h>
#include "hl_base.h"
#include "hl_cnn.h"
L
liaogang 已提交
15
#include "hl_device_functions.cuh"
Z
zhangjinchao01 已提交
16

L
liaogang 已提交
17 18 19 20
/*
 * 2D max-pooling forward pass: one thread per pooled output element.
 *
 * The flat thread id decomposes, fastest-varying first, into
 * (column, row, channel, frame) over (pooledW, pooledH, channels, frames).
 * Threads whose id is >= nthreads exit immediately.
 *
 * The window of output (row, col) starts at
 * (row * strideH - offsetH, col * strideW - offsetW) and spans ksizeH x ksizeW.
 * It is clipped to the image [0, height) x [0, width), so padded positions
 * never contribute. If the clipped window is empty the stored value is
 * -FLT_MAX.
 *
 * Input is read as one height*width plane per (frame, channel). The result
 * for frame f is written at f * tgtStride plus the element's offset inside
 * that frame's channels*pooledH*pooledW block.
 */
__global__ void KeMaxPoolForward(const int nthreads,
                                 const real* inputData,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int pooledH,
                                 const int pooledW,
                                 const int ksizeW,
                                 const int ksizeH,
                                 const int strideH,
                                 const int strideW,
                                 const int offsetH,
                                 const int offsetW,
                                 real* tgtData,
                                 const int tgtStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  // Unflatten the thread id into (frame, channel, row, col) of the output.
  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // Window bounds: clamp the end against the image from the unclipped start,
  // then clamp the start itself.
  const int rowBegin = outRow * strideH - offsetH;
  const int colBegin = outCol * strideW - offsetW;
  const int rowEnd = min(rowBegin + ksizeH, height);
  const int colEnd = min(colBegin + ksizeW, width);
  const int rowFirst = max(rowBegin, 0);
  const int colFirst = max(colBegin, 0);

  const real* plane = inputData + (frame * channels + chan) * height * width;
  real best = -FLT_MAX;
  for (int r = rowFirst; r < rowEnd; ++r) {
    const real* line = plane + r * width;
    for (int col = colFirst; col < colEnd; ++col) {
      const real v = line[col];
      if (best < v) best = v;
    }
  }

  // Offset of this element inside its frame's channels*pooledH*pooledW block.
  const int inFrame = (chan * pooledH + outRow) * pooledW + outCol;
  tgtData[inFrame + frame * tgtStride] = best;
}

L
liaogang 已提交
58 59
/*
 * Launches KeMaxPoolForward on STREAM_DEFAULT.
 *
 * One thread is launched per pooled output element
 * (frameCnt * channels * pooledH * pooledW), 1024 threads per block.
 * sizeX/sizeY are the window width/height. paddingH/paddingW are forwarded
 * as the kernel's offsetH/offsetW.
 *
 * When there is nothing to compute the function returns without launching,
 * because a grid with zero blocks is an invalid launch configuration
 * (cudaErrorInvalidConfiguration).
 * CHECK_SYNC is the project's post-launch check (defined outside this file).
 */
void hl_maxpool_forward(const int frameCnt,
                        const real* inputData,
                        const int channels,
                        const int height,
                        const int width,
                        const int pooledH,
                        const int pooledW,
                        const int sizeX,
                        const int sizeY,
                        const int strideH,
                        const int strideW,
                        const int paddingH,
                        const int paddingW,
                        real* tgtData,
                        const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  // Empty output: skip the launch instead of requesting a 0-block grid.
  if (num_kernels == 0) return;
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
  dim3 threads(kThreadsPerBlock, 1);
  dim3 grid(blocks, 1);

  KeMaxPoolForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
                                                         inputData,
                                                         channels,
                                                         height,
                                                         width,
                                                         pooledH,
                                                         pooledW,
                                                         sizeX,
                                                         sizeY,
                                                         strideH,
                                                         strideW,
                                                         paddingH,
                                                         paddingW,
                                                         tgtData,
                                                         tgtStride);
  CHECK_SYNC("hl_maxpool_forward failed");
}

L
liaogang 已提交
96 97 98 99 100 101
/*
 * 2D max-pooling backward pass: one thread per input element.
 *
 * The flat thread id decomposes, fastest-varying first, into
 * (x, y, channel, frame) over (width, height, channels, frames).
 * Threads whose id is >= nthreads exit immediately.
 *
 * The thread visits every pooled output whose window covers this element,
 * measured in padded coordinates. Each such output whose stored value in
 * outData equals this element's input value adds its outGrad entry to the
 * sum. The target is then updated in place:
 *   targetGrad = scaleB * targetGrad + scaleA * sum
 * outData/outGrad are addressed as frame * outStride + channel * pooledH *
 * pooledW + (row * pooledW + col).
 */
__global__ void KeMaxPoolBackward(const int nthreads,
                                  const real* inputData,
                                  const real* outData,
                                  const real* outGrad,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int pooledH,
                                  const int pooledW,
                                  const int sizeX,
                                  const int sizeY,
                                  const int strideH,
                                  const int strideW,
                                  const int padH,
                                  const int padW,
                                  real scaleA,
                                  real scaleB,
                                  real* targetGrad,
                                  const int outStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  // Position of this input element, shifted into padded coordinates.
  const int xPad = tid % width + padW;
  const int yPad = (tid / width) % height + padH;
  const int chan = (tid / width / height) % channels;
  const int frame = tid / width / height / channels;

  // Half-open range of pooled rows/cols whose window contains (yPad, xPad).
  const int rowFirst = (yPad < sizeY) ? 0 : (yPad - sizeY) / strideH + 1;
  const int colFirst = (xPad < sizeX) ? 0 : (xPad - sizeX) / strideW + 1;
  int rowLast = 0;
  if (yPad >= 0) rowLast = min(yPad / strideH + 1, pooledH);
  int colLast = 0;
  if (xPad >= 0) colLast = min(xPad / strideW + 1, pooledW);

  const int base = frame * outStride + chan * pooledH * pooledW;
  const real* pooledVal = outData + base;
  const real* pooledGrad = outGrad + base;
  const real inVal = inputData[tid];

  real acc = 0;
  for (int r = rowFirst; r < rowLast; ++r) {
    for (int col = colFirst; col < colLast; ++col) {
      const int k = r * pooledW + col;
      if (inVal == pooledVal[k]) {
        acc += pooledGrad[k];
      }
    }
  }
  targetGrad[tid] = scaleB * targetGrad[tid] + scaleA * acc;
}

L
liaogang 已提交
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
/*
 * Launches KeMaxPoolBackward on STREAM_DEFAULT.
 *
 * One thread is launched per input element
 * (frameCnt * channels * height * width), 1024 threads per block.
 * The kernel accumulates into targetGrad as
 * scaleB * targetGrad + scaleA * gradient.
 *
 * When there is nothing to compute the function returns without launching,
 * because a grid with zero blocks is an invalid launch configuration
 * (cudaErrorInvalidConfiguration).
 * CHECK_SYNC is the project's post-launch check (defined outside this file).
 */
void hl_maxpool_backward(const int frameCnt,
                         const real* inputData,
                         const real* outData,
                         const real* outGrad,
                         const int channels,
                         const int height,
                         const int width,
                         const int pooledH,
                         const int pooledW,
                         const int sizeX,
                         const int sizeY,
                         const int strideH,
                         const int strideW,
                         const int paddingH,
                         const int paddingW,
                         real scaleA,
                         real scaleB,
                         real* targetGrad,
                         const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = height * width * channels * frameCnt;
  // Empty input: skip the launch instead of requesting a 0-block grid.
  if (num_kernels == 0) return;
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeMaxPoolBackward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      num_kernels,
      inputData,
      outData,
      outGrad,
      channels,
      height,
      width,
      pooledH,
      pooledW,
      sizeX,
      sizeY,
      strideH,
      strideW,
      paddingH,
      paddingW,
      scaleA,
      scaleB,
      targetGrad,
      outStride);
  CHECK_SYNC("hl_maxpool_backward");
}

L
liaogang 已提交
187 188
/*
 * 2D average-pooling forward pass: one thread per pooled output element.
 *
 * The flat thread id decomposes, fastest-varying first, into
 * (column, row, channel, frame) over (pooledW, pooledH, channels, frames).
 * Threads whose id is >= nthreads exit immediately.
 *
 * The window starts at (row * strideH - padH, col * strideW - padW) and is
 * first limited to the padded extent (height + padH, width + padW). Its area
 * at that point, padding included, is the divisor. Only the part inside the
 * real image [0, height) x [0, width) is summed.
 *
 * The result for frame f is written at f * tgtStride plus the element's
 * offset inside that frame's channels*pooledH*pooledW block.
 */
__global__ void KeAvgPoolForward(const int nthreads,
                                 const real* inputData,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int pooledH,
                                 const int pooledW,
                                 const int sizeX,
                                 const int sizeY,
                                 const int strideH,
                                 const int strideW,
                                 const int padH,
                                 const int padW,
                                 real* tgtData,
                                 const int tgtStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // Window in padded coordinates; its area is the averaging divisor.
  const int rowBegin = outRow * strideH - padH;
  const int colBegin = outCol * strideW - padW;
  const int rowStop = min(rowBegin + sizeY, height + padH);
  const int colStop = min(colBegin + sizeX, width + padW);
  const int windowArea = (rowStop - rowBegin) * (colStop - colBegin);

  // Part of the window that lies inside the real image.
  const int rowFirst = max(rowBegin, 0);
  const int colFirst = max(colBegin, 0);
  const int rowLast = min(rowStop, height);
  const int colLast = min(colStop, width);

  const real* plane = inputData + (frame * channels + chan) * height * width;
  real sum = 0;
  for (int r = rowFirst; r < rowLast; ++r) {
    const real* line = plane + r * width;
    for (int col = colFirst; col < colLast; ++col) {
      sum += line[col];
    }
  }

  const int inFrame = (chan * pooledH + outRow) * pooledW + outCol;
  tgtData[inFrame + frame * tgtStride] = sum / windowArea;
}

L
liaogang 已提交
232 233
/*
 * Launches KeAvgPoolForward on STREAM_DEFAULT.
 *
 * One thread is launched per pooled output element
 * (frameCnt * channels * pooledH * pooledW), 1024 threads per block.
 *
 * When there is nothing to compute the function returns without launching,
 * because a grid with zero blocks is an invalid launch configuration
 * (cudaErrorInvalidConfiguration).
 * CHECK_SYNC is the project's post-launch check (defined outside this file).
 */
void hl_avgpool_forward(const int frameCnt,
                        const real* inputData,
                        const int channels,
                        const int height,
                        const int width,
                        const int pooledH,
                        const int pooledW,
                        const int sizeX,
                        const int sizeY,
                        const int strideH,
                        const int strideW,
                        const int paddingH,
                        const int paddingW,
                        real* tgtData,
                        const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  // Empty output: skip the launch instead of requesting a 0-block grid.
  if (num_kernels == 0) return;
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPoolForward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      num_kernels,
      inputData,
      channels,
      height,
      width,
      pooledH,
      pooledW,
      sizeX,
      sizeY,
      strideH,
      strideW,
      paddingH,
      paddingW,
      tgtData,
      tgtStride);
  CHECK_SYNC("hl_avgpool_forward failed");
}

L
liaogang 已提交
267 268 269 270
/*
 * 2D average-pooling backward pass: one thread per input element.
 *
 * The flat thread id decomposes, fastest-varying first, into
 * (x, y, channel, frame) over (width, height, channels, frames).
 * Threads whose id is >= nthreads exit immediately.
 *
 * Every pooled output whose window covers this element (padded coordinates)
 * contributes outGrad / windowArea. windowArea is recomputed exactly as in
 * KeAvgPoolForward, i.e. limited to the padded extent with padding counted.
 * The target is then updated in place:
 *   tgtGrad = scaleB * tgtGrad + scaleA * sum
 * outGrad is addressed as frame * outStride + channel * pooledH * pooledW +
 * (row * pooledW + col).
 */
__global__ void KeAvgPoolBackward(const int nthreads,
                                  const real* outGrad,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int pooledH,
                                  const int pooledW,
                                  const int sizeX,
                                  const int sizeY,
                                  const int strideH,
                                  const int strideW,
                                  const int padH,
                                  const int padW,
                                  real scaleA,
                                  real scaleB,
                                  real* tgtGrad,
                                  const int outStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) return;

  // Position of this input element, shifted into padded coordinates.
  const int xPad = tid % width + padW;
  const int yPad = (tid / width) % height + padH;
  const int chan = (tid / width / height) % channels;
  const int frame = tid / width / height / channels;

  // Half-open range of pooled rows/cols whose window contains (yPad, xPad).
  const int rowFirst = (yPad < sizeY) ? 0 : (yPad - sizeY) / strideH + 1;
  const int colFirst = (xPad < sizeX) ? 0 : (xPad - sizeX) / strideW + 1;
  int rowLast = 0;
  if (yPad >= 0) rowLast = min(yPad / strideH + 1, pooledH);
  int colLast = 0;
  if (xPad >= 0) colLast = min(xPad / strideW + 1, pooledW);

  const real* pooledGrad =
      outGrad + frame * outStride + chan * pooledH * pooledW;

  real acc = 0;
  for (int r = rowFirst; r < rowLast; ++r) {
    // Height of window r, limited to the padded extent.
    const int top = r * strideH - padH;
    const int rowSpan = min(top + sizeY, height + padH) - top;
    for (int col = colFirst; col < colLast; ++col) {
      // Width of window col, limited to the padded extent.
      const int left = col * strideW - padW;
      const int colSpan = min(left + sizeX, width + padW) - left;
      const int windowArea = rowSpan * colSpan;
      acc += pooledGrad[r * pooledW + col] / windowArea;
    }
  }
  tgtGrad[tid] = scaleB * tgtGrad[tid] + scaleA * acc;
}

L
liaogang 已提交
313 314
/*
 * Launches KeAvgPoolBackward on STREAM_DEFAULT.
 *
 * One thread is launched per input element
 * (frameCnt * channels * height * width), 1024 threads per block.
 * The kernel accumulates into backGrad as
 * scaleB * backGrad + scaleA * gradient.
 *
 * When there is nothing to compute the function returns without launching,
 * because a grid with zero blocks is an invalid launch configuration
 * (cudaErrorInvalidConfiguration).
 * CHECK_SYNC is the project's post-launch check (defined outside this file).
 */
void hl_avgpool_backward(const int frameCnt,
                         const real* outGrad,
                         const int channels,
                         const int height,
                         const int width,
                         const int pooledH,
                         const int pooledW,
                         const int sizeX,
                         const int sizeY,
                         const int strideH,
                         const int strideW,
                         const int paddingH,
                         const int paddingW,
                         real scaleA,
                         real scaleB,
                         real* backGrad,
                         const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = height * width * channels * frameCnt;
  // Empty input: skip the launch instead of requesting a 0-block grid.
  if (num_kernels == 0) return;
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPoolBackward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      num_kernels,
      outGrad,
      channels,
      height,
      width,
      pooledH,
      pooledW,
      sizeX,
      sizeY,
      strideH,
      strideW,
      paddingH,
      paddingW,
      scaleA,
      scaleB,
      backGrad,
      outStride);
  CHECK_SYNC("hl_avgpool_backward failed");
}

C
chengduoZH 已提交
353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776
/////////////////
/*
 * 3D max-pooling forward pass.
 *
 * Each pooled output element is handled by one loop iteration. The loop
 * advances by blockDim.x * gridDim.x, so any grid size covers all nthreads
 * elements. The flat index decomposes, fastest-varying first, into
 * (pw, ph, pd, c, frame) over (pooledW, pooledH, pooledD, channels, frames).
 *
 * The window starts at
 * (pd * strideD - offsetD, ph * strideH - offsetH, pw * strideW - offsetW),
 * spans ksizeD x ksizeH x ksizeW, and is clipped to
 * [0, depth) x [0, height) x [0, width). If the clipped window is empty the
 * stored value is -FLT_MAX.
 *
 * The result for frame f is written at f * tgtStride plus the element's
 * offset inside that frame's channels*pooledD*pooledH*pooledW block.
 */
__global__ void KeMaxPool3DForward(const int nthreads,
                                   const real* inputData,
                                   const int channels,
                                   const int depth,
                                   const int height,
                                   const int width,
                                   const int pooledD,
                                   const int pooledH,
                                   const int pooledW,
                                   const int ksizeD,
                                   const int ksizeH,
                                   const int ksizeW,
                                   const int strideD,
                                   const int strideH,
                                   const int strideW,
                                   const int offsetD,
                                   const int offsetH,
                                   const int offsetW,
                                   real* tgtData,
                                   const int tgtStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int pd = (index / pooledW / pooledH) % pooledD;
    int c = (index / pooledW / pooledH / pooledD) % channels;
    int frameNum = index / pooledW / pooledH / pooledD / channels;
    int dstart = pd * strideD - offsetD;
    int hstart = ph * strideH - offsetH;
    int wstart = pw * strideW - offsetW;
    int dend = min(dstart + ksizeD, depth);
    int hend = min(hstart + ksizeH, height);
    int wend = min(wstart + ksizeW, width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    real maxval = -FLT_MAX;
    // Per-iteration pointer: advancing `inputData` itself would carry this
    // offset into the next grid-stride iteration and read the wrong volume.
    const real* volume =
        inputData + (frameNum * channels + c) * depth * height * width;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          if (maxval < volume[(d * height + h) * width + w])
            maxval = volume[(d * height + h) * width + w];
        }
      }
    }
    int tgtIndex =
        index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride;
    tgtData[tgtIndex] = maxval;
  }
}

/*
 * Launches KeMaxPool3DForward on STREAM_DEFAULT.
 *
 * One thread is launched per pooled output element
 * (frameCnt * channels * pooledD * pooledH * pooledW), 1024 threads per block.
 * sizeZ/sizeY/sizeX are the window depth/height/width. The paddings are
 * forwarded as the kernel's offsets.
 *
 * When there is nothing to compute the function returns without launching,
 * because a grid with zero blocks is an invalid launch configuration
 * (cudaErrorInvalidConfiguration).
 * CHECK_SYNC is the project's post-launch check (defined outside this file).
 */
void hl_maxpool3D_forward(const int frameCnt,
                          const real* inputData,
                          const int channels,
                          const int depth,
                          const int height,
                          const int width,
                          const int pooledD,
                          const int pooledH,
                          const int pooledW,
                          const int sizeZ,
                          const int sizeY,
                          const int sizeX,
                          const int strideD,
                          const int strideH,
                          const int strideW,
                          const int paddingD,
                          const int paddingH,
                          const int paddingW,
                          real* tgtData,
                          const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledD * pooledH * pooledW * channels * frameCnt;
  // Empty output: skip the launch instead of requesting a 0-block grid.
  if (num_kernels == 0) return;
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
  dim3 threads(kThreadsPerBlock, 1);
  dim3 grid(blocks, 1);

  KeMaxPool3DForward<<<grid, threads, 0, STREAM_DEFAULT>>>(num_kernels,
                                                           inputData,
                                                           channels,
                                                           depth,
                                                           height,
                                                           width,
                                                           pooledD,
                                                           pooledH,
                                                           pooledW,
                                                           sizeZ,
                                                           sizeY,
                                                           sizeX,
                                                           strideD,
                                                           strideH,
                                                           strideW,
                                                           paddingD,
                                                           paddingH,
                                                           paddingW,
                                                           tgtData,
                                                           tgtStride);
  CHECK_SYNC("hl_maxpool3D_forward failed");
}

/*
 * 3D max-pooling backward pass.
 *
 * Each input element is handled by one loop iteration. The loop advances by
 * blockDim.x * gridDim.x, so any grid size covers all nthreads elements.
 * The flat index decomposes, fastest-varying first, into (x, y, z, c, frame)
 * over (width, height, depth, channels, frames).
 *
 * Every pooled output whose window covers this element (padded coordinates)
 * and whose stored value in outData equals this element's input value adds
 * its outGrad entry to the sum. The target is then updated in place:
 *   targetGrad = scaleA * sum + scaleB * targetGrad
 * outData/outGrad are addressed as frame * outStride + c * pooledD * pooledH *
 * pooledW, matching the tgtStride layout written by KeMaxPool3DForward and
 * the outStride addressing used by the 2D backward kernel.
 */
__global__ void KeMaxPool3DBackward(const int nthreads,
                                    const real* inputData,
                                    const real* outData,
                                    const real* outGrad,
                                    const int channels,
                                    const int depth,
                                    const int height,
                                    const int width,
                                    const int pooledD,
                                    const int pooledH,
                                    const int pooledW,
                                    const int sizeZ,
                                    const int sizeY,
                                    const int sizeX,
                                    const int strideD,
                                    const int strideH,
                                    const int strideW,
                                    const int padD,
                                    const int padH,
                                    const int padW,
                                    real scaleA,
                                    real scaleB,
                                    real* targetGrad,
                                    const int outStride) {
  const int pooledVolume = pooledD * pooledH * pooledW;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    // Position of this input element, shifted into padded coordinates.
    int offsetW = index % width + padW;
    int offsetH = (index / width) % height + padH;
    int offsetD = (index / width / height) % depth + padD;
    int offsetC = (index / width / height / depth) % channels;
    int frameNum = index / width / height / depth / channels;

    // Half-open ranges of pooled indices whose window contains this element.
    // A negative padded offset lies before every window, hence the guards.
    int pdstart = (offsetD < sizeZ) ? 0 : (offsetD - sizeZ) / strideD + 1;
    int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1;
    int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1;
    int pdend = offsetD >= 0 ? min(offsetD / strideD + 1, pooledD) : 0;
    int phend = offsetH >= 0 ? min(offsetH / strideH + 1, pooledH) : 0;
    int pwend = offsetW >= 0 ? min(offsetW / strideW + 1, pooledW) : 0;

    real gradient = 0;
    real input = inputData[index];

    // Per-iteration pointers: advancing `outData`/`outGrad` themselves would
    // carry this offset into the next grid-stride iteration.
    const int outOffset = frameNum * outStride + offsetC * pooledVolume;
    const real* outDataC = outData + outOffset;
    const real* outGradC = outGrad + outOffset;
    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          if (input == outDataC[(pd * pooledH + ph) * pooledW + pw])
            gradient += outGradC[(pd * pooledH + ph) * pooledW + pw];
        }
      }
    }
    targetGrad[index] = scaleA * gradient + scaleB * targetGrad[index];
  }
}

/*
 * Launches KeMaxPool3DBackward on STREAM_DEFAULT.
 *
 * One thread is launched per input element
 * (frameCnt * channels * depth * height * width), 1024 threads per block.
 * outputD/outputH/outputW are the pooled extents (the kernel's
 * pooledD/H/W). The kernel accumulates into targetGrad as
 * scaleA * gradient + scaleB * targetGrad.
 *
 * When there is nothing to compute the function returns without launching,
 * because a grid with zero blocks is an invalid launch configuration
 * (cudaErrorInvalidConfiguration).
 * CHECK_SYNC is the project's post-launch check (defined outside this file).
 */
void hl_maxpool3D_backward(const int frameCnt,
                           const real* inputData,
                           const real* outData,
                           const real* outGrad,
                           const int channels,
                           const int depth,
                           const int height,
                           const int width,
                           const int outputD,
                           const int outputH,
                           const int outputW,
                           const int sizeZ,
                           const int sizeY,
                           const int sizeX,
                           const int strideD,
                           const int strideH,
                           const int strideW,
                           const int paddingD,
                           const int paddingH,
                           const int paddingW,
                           real scaleA,
                           real scaleB,
                           real* targetGrad,
                           const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = depth * height * width * channels * frameCnt;
  // Empty input: skip the launch instead of requesting a 0-block grid.
  if (num_kernels == 0) return;
  int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeMaxPool3DBackward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      num_kernels,
      inputData,
      outData,
      outGrad,
      channels,
      depth,
      height,
      width,
      outputD,
      outputH,
      outputW,
      sizeZ,
      sizeY,
      sizeX,
      strideD,
      strideH,
      strideW,
      paddingD,
      paddingH,
      paddingW,
      scaleA,
      scaleB,
      targetGrad,
      outStride);
  CHECK_SYNC("hl_maxpool3D_backward");
}

/*
 * 3D average-pooling forward pass.
 *
 * Each pooled output element is handled by one loop iteration. The loop
 * advances by blockDim.x * gridDim.x, so any grid size covers all nthreads
 * elements. The flat index decomposes, fastest-varying first, into
 * (pw, ph, pd, c, frame) over (pooledW, pooledH, pooledD, channels, frames).
 *
 * The window starts at (pd * strideD - padD, ph * strideH - padH,
 * pw * strideW - padW) and is first limited to the padded extent
 * (depth + padD, height + padH, width + padW). Its volume at that point,
 * padding included, is the divisor. Only the part inside
 * [0, depth) x [0, height) x [0, width) is summed.
 *
 * The result for frame f is written at f * tgtStride plus the element's
 * offset inside that frame's channels*pooledD*pooledH*pooledW block.
 */
__global__ void KeAvgPool3DForward(const int nthreads,
                                   const real* inputData,
                                   const int channels,
                                   const int depth,
                                   const int height,
                                   const int width,
                                   const int pooledD,
                                   const int pooledH,
                                   const int pooledW,
                                   const int sizeZ,
                                   const int sizeY,
                                   const int sizeX,
                                   const int strideD,
                                   const int strideH,
                                   const int strideW,
                                   const int padD,
                                   const int padH,
                                   const int padW,
                                   real* tgtData,
                                   const int tgtStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    int pw = index % pooledW;
    int ph = (index / pooledW) % pooledH;
    int pd = (index / pooledW / pooledH) % pooledD;
    int c = (index / pooledW / pooledH / pooledD) % channels;
    int frameNum = index / pooledW / pooledH / pooledD / channels;
    int dstart = pd * strideD - padD;
    int hstart = ph * strideH - padH;
    int wstart = pw * strideW - padW;
    int dend = min(dstart + sizeZ, depth + padD);
    int hend = min(hstart + sizeY, height + padH);
    int wend = min(wstart + sizeX, width + padW);
    // Divisor counts padded positions; the summed region below does not.
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    dend = min(dend, depth);
    hend = min(hend, height);
    wend = min(wend, width);

    real aveval = 0;
    // Per-iteration pointer: advancing `inputData` itself would carry this
    // offset into the next grid-stride iteration and read the wrong volume.
    const real* volume =
        inputData + (frameNum * channels + c) * depth * height * width;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          aveval += volume[(d * height + h) * width + w];
        }
      }
    }
    int tgtIndex =
        index % (pooledW * pooledH * pooledD * channels) + frameNum * tgtStride;
    tgtData[tgtIndex] = aveval / pool_size;
  }
}

/**
 * @brief Host launcher for KeAvgPool3DForward (3D average pooling, forward).
 *
 * Requests one thread per pooled output element, i.e.
 * pooledD * pooledH * pooledW * channels * frameCnt threads, grouped into
 * 1024-thread blocks on STREAM_DEFAULT, then runs CHECK_SYNC.
 *
 * paddingD/H/W are handed to the kernel as its padD/H/W arguments, and
 * tgtStride is the distance between consecutive frames in tgtData.
 */
void hl_avgpool3D_forward(const int frameCnt,
                          const real* inputData,
                          const int channels,
                          const int depth,
                          const int height,
                          const int width,
                          const int pooledD,
                          const int pooledH,
                          const int pooledW,
                          const int sizeZ,
                          const int sizeY,
                          const int sizeX,
                          const int strideD,
                          const int strideH,
                          const int strideW,
                          const int paddingD,
                          const int paddingH,
                          const int paddingW,
                          real* tgtData,
                          const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  const int outputCount = pooledD * pooledH * pooledW * channels * frameCnt;
  const int gridSize = (outputCount + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPool3DForward<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      outputCount,
      inputData,
      channels,
      depth, height, width,
      pooledD, pooledH, pooledW,
      sizeZ, sizeY, sizeX,
      strideD, strideH, strideW,
      paddingD, paddingH, paddingW,
      tgtData,
      tgtStride);
  CHECK_SYNC("hl_avgpool3D_forward failed");
}

/**
 * @brief Backward pass of 3D average pooling.
 *
 * Each loop iteration handles one element of the (unpadded) input volume,
 * flattened as [frame][channel][depth][height][width] with width fastest.
 * It sums the gradients of every pooled cell whose window covers that
 * element. Each contribution is divided by that window's pool size, computed
 * the same way as in KeAvgPool3DForward: the window is clipped to the padded
 * extent (depth + padD, height + padH, width + padW), so padded cells count
 * towards the divisor.
 *
 * The result is blended into the existing gradient:
 *   tgtGrad[index] = scaleA * gradient + scaleB * tgtGrad[index]
 *
 * The loop advances by blockDim.x * gridDim.x, so any grid size covers all
 * nthreads elements.
 *
 * @param outGrad    Gradient w.r.t. the pooled output. It is read as a packed
 *                   [frame][channel][pooledD][pooledH][pooledW] array.
 * @param outStride  Not read. NOTE(review): outGrad frames are assumed to be
 *                   packed (channels * pooledD * pooledH * pooledW apart),
 *                   although the forward kernel honours tgtStride. Confirm
 *                   that callers never pass a padded outGrad.
 */
__global__ void KeAvgPool3DBackward(const int nthreads,
                                    const real* outGrad,
                                    const int channels,
                                    const int depth,
                                    const int height,
                                    const int width,
                                    const int pooledD,
                                    const int pooledH,
                                    const int pooledW,
                                    const int sizeZ,
                                    const int sizeY,
                                    const int sizeX,
                                    const int strideD,
                                    const int strideH,
                                    const int strideW,
                                    const int padD,
                                    const int padH,
                                    const int padW,
                                    real scaleA,
                                    real scaleB,
                                    real* tgtGrad,
                                    const int outStride) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted into padded space.
    int offsetW = index % width + padW;
    int offsetH = (index / width) % height + padH;
    int offsetD = (index / width / height) % depth + padD;
    int offsetC = (index / width / height / depth) % channels;
    int frameNum = index / width / height / depth / channels;

    // Range of pooled cells whose windows contain this element.
    int pdstart = (offsetD < sizeZ) ? 0 : (offsetD - sizeZ) / strideD + 1;
    int phstart = (offsetH < sizeY) ? 0 : (offsetH - sizeY) / strideH + 1;
    int pwstart = (offsetW < sizeX) ? 0 : (offsetW - sizeX) / strideW + 1;
    int pdend = min(offsetD / strideD + 1, pooledD);
    int phend = min(offsetH / strideH + 1, pooledH);
    int pwend = min(offsetW / strideW + 1, pooledW);

    real gradient = 0;
    // Use a per-iteration base pointer. Advancing the `outGrad` parameter in
    // place would compound the offset on every grid-stride iteration after
    // the first one.
    const real* outGradSlice =
        outGrad + (frameNum * channels + offsetC) * pooledD * pooledH * pooledW;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size (must match the forward pass)
          int dstart = pd * strideD - padD;
          int hstart = ph * strideH - padH;
          int wstart = pw * strideW - padW;
          int dend = min(dstart + sizeZ, depth + padD);
          int hend = min(hstart + sizeY, height + padH);
          int wend = min(wstart + sizeX, width + padW);
          int poolsize = (dend - dstart) * (hend - hstart) * (wend - wstart);
          gradient +=
              outGradSlice[(pd * pooledH + ph) * pooledW + pw] / poolsize;
        }
      }
    }
    tgtGrad[index] = scaleA * gradient + scaleB * tgtGrad[index];
  }
}

/**
 * @brief Host launcher for KeAvgPool3DBackward (3D average pooling, backward).
 *
 * Requests one thread per input element, i.e.
 * depth * height * width * channels * frameCnt threads, grouped into
 * 1024-thread blocks on STREAM_DEFAULT, then runs CHECK_SYNC.
 *
 * outputD/H/W are the kernel's pooledD/H/W and paddingD/H/W its padD/H/W.
 * backGrad receives scaleA * gradient + scaleB * backGrad.
 */
void hl_avgpool3D_backward(const int frameCnt,
                           const real* outGrad,
                           const int channels,
                           const int depth,
                           const int height,
                           const int width,
                           const int outputD,
                           const int outputH,
                           const int outputW,
                           const int sizeZ,
                           const int sizeY,
                           const int sizeX,
                           const int strideD,
                           const int strideH,
                           const int strideW,
                           int paddingD,
                           int paddingH,
                           int paddingW,
                           real scaleA,
                           real scaleB,
                           real* backGrad,
                           const int outStride) {
  const int kThreadsPerBlock = 1024;
  const int inputCount = depth * height * width * channels * frameCnt;
  const int gridSize = (inputCount + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeAvgPool3DBackward<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      inputCount,
      outGrad,
      channels,
      depth, height, width,
      outputD, outputH, outputW,
      sizeZ, sizeY, sizeX,
      strideD, strideH, strideW,
      paddingD, paddingH, paddingW,
      scaleA, scaleB,
      backGrad,
      outStride);
  CHECK_SYNC("hl_avgpool3D_backward failed");
}
/////////////////

/**
 * @brief Bilinear interpolation, forward pass.
 *
 * `in` has inputH rows of inputW values and `out` has outputH rows of
 * outputW values. Each row holds numChannels images side by side. An input
 * image is inImgH x inImgW, an output image is outImgH x outImgW, and both
 * are stored row-major.
 *
 * One thread computes one element of `out` (outputH * outputW threads in
 * total). The source coordinate is the destination coordinate scaled by
 * ratioH / ratioW. The value is a weighted blend of the source cell and its
 * right, lower and lower-right neighbours. In the last input row or column
 * the neighbour offset is 0, so that edge pixel is reused.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int total = outputH * outputW;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total) {
    return;
  }

  // Locate the sample row, the column within it, and the channel image.
  const int row = tid / outputW;
  const int col = tid % outputW;
  const int inChannelLen = inputW / numChannels;
  const int outChannelLen = outputW / numChannels;
  const int channel = col / outChannelLen;

  // Vertical source position and blend weights.
  const int dstY = (col % outChannelLen) / outImgW;
  const int srcY = ratioH * dstY;
  const int dy = (srcY < inImgH - 1) ? 1 : 0;
  const real wy1 = ratioH * dstY - srcY;
  const real wy0 = 1.f - wy1;

  // Horizontal source position and blend weights.
  const int dstX = tid % outImgW;
  const int srcX = ratioW * dstX;
  const int dx = (srcX < inImgW - 1) ? 1 : 0;
  const real wx1 = ratioW * dstX - srcX;
  const real wx0 = 1.f - wx1;

  const real* src = &in[row * inputW + channel * inChannelLen +
                        srcY * inImgW + srcX];

  out[row * outputW + col] =
      wy0 * (wx0 * src[0] + wx1 * src[dx]) +
      wy1 * (wx0 * src[dy * inImgW] + wx1 * src[dy * inImgW + dx]);
}

/**
 * @brief Host launcher for KeBilinearInterpFw.
 *
 * Requests outputH * outputW threads (one per output element) in
 * 1024-thread blocks on STREAM_DEFAULT. Every argument is forwarded to the
 * kernel unchanged, and CHECK_SYNC runs afterwards.
 */
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int kThreadsPerBlock = 1024;
  const int outputCount = outputH * outputW;
  const int gridSize = (outputCount + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeBilinearInterpFw<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      inData, inImgH, inImgW, inputH, inputW,
      outData, outImgH, outImgW, outputH, outputW,
      numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}

/**
 * @brief Bilinear interpolation, backward pass.
 *
 * Uses the same row/channel/image layout and the same source-coordinate and
 * weight computation as KeBilinearInterpFw. One thread handles one element
 * of `out` (outputH * outputW threads). It scatters that element's gradient
 * into up to four cells of `in`, scaled by the bilinear weights.
 *
 * Several output elements can map to the same input cell, so the updates
 * use paddle::paddleAtomicAdd. `in` is accumulated into, not overwritten.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int total = outputH * outputW;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total) {
    return;
  }

  // Locate the sample row, the column within it, and the channel image.
  const int row = tid / outputW;
  const int col = tid % outputW;
  const int inChannelLen = inputW / numChannels;
  const int outChannelLen = outputW / numChannels;
  const int channel = col / outChannelLen;

  // Vertical source position and blend weights.
  const int dstY = (col % outChannelLen) / outImgW;
  const int srcY = ratioH * dstY;
  const int dy = (srcY < inImgH - 1) ? 1 : 0;
  const real wy1 = ratioH * dstY - srcY;
  const real wy0 = 1.f - wy1;

  // Horizontal source position and blend weights.
  const int dstX = tid % outImgW;
  const int srcX = ratioW * dstX;
  const int dx = (srcX < inImgW - 1) ? 1 : 0;
  const real wx1 = ratioW * dstX - srcX;
  const real wx0 = 1.f - wx1;

  real* dst = &in[row * inputW + channel * inChannelLen +
                  srcY * inImgW + srcX];
  const real grad = out[row * outputW + col];

  paddle::paddleAtomicAdd(&dst[0], wy0 * wx0 * grad);
  paddle::paddleAtomicAdd(&dst[dx], wy0 * wx1 * grad);
  paddle::paddleAtomicAdd(&dst[dy * inImgW], wy1 * wx0 * grad);
  paddle::paddleAtomicAdd(&dst[dy * inImgW + dx], wy1 * wx1 * grad);
}

/**
 * @brief Host launcher for KeBilinearInterpBw.
 *
 * Requests outputH * outputW threads (one per element of outGrad) in
 * 1024-thread blocks on STREAM_DEFAULT. Every argument is forwarded to the
 * kernel unchanged; inGrad is accumulated into. CHECK_SYNC runs afterwards.
 */
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int kThreadsPerBlock = 1024;
  const int outputCount = outputH * outputW;
  const int gridSize = (outputCount + kThreadsPerBlock - 1) / kThreadsPerBlock;

  KeBilinearInterpBw<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      inGrad, inImgH, inImgW, inputH, inputW,
      outGrad, outImgH, outImgW, outputH, outputW,
      numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}

/**
 * @brief Maxout forward: each output element is the maximum over `groups`
 * candidate input elements.
 *
 * One thread handles one output element. hl_maxout_forward launches
 * batchSize * size of them. Output element i = channel * featLen + feat of
 * a sample reads its candidates from that sample's input slice (size *
 * groups elements) at offsets (channel * groups + g) * featLen + feat, for
 * g = 0 .. groups - 1.
 *
 * The winning value is written to outData and its group index g to idData.
 * The comparison is a strict `>`, so ties keep the lowest group index.
 */
__global__ void maxoutFpCompute(size_t nthreads,
                                const real* inData,
                                real* outData,
                                int* idData,
                                size_t size,
                                size_t featLen,
                                size_t groups) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  const size_t sample = index / size;
  const size_t offsetInSample = index % size;
  const size_t channel = offsetInSample / featLen;
  const size_t feat = offsetInSample % featLen;

  // Group 0's candidate; group g lies g * featLen elements further on.
  const real* candidates =
      inData + ((sample * size + channel * featLen) * groups + feat);

  real best = candidates[0];
  int bestGroup = 0;
  for (size_t g = 1; g < groups; ++g) {
    const real value = candidates[g * featLen];
    if (value > best) {
      best = value;
      bestGroup = g;
    }
  }

  outData[index] = best;
  idData[index] = bestGroup;
}

/**
 * @brief Host launcher for maxoutFpCompute.
 *
 * Requests size * batchSize threads (one per output element) in 1024-thread
 * blocks on STREAM_DEFAULT, then runs CHECK_SYNC.
 */
void hl_maxout_forward(const real* inData,
                       real* outData,
                       int* idData,
                       size_t batchSize,
                       size_t size,
                       size_t featLen,
                       size_t groups) {
  const int kThreadsPerBlock = 1024;
  const int outputCount = size * batchSize;
  const int gridSize = (outputCount + kThreadsPerBlock - 1) / kThreadsPerBlock;

  maxoutFpCompute<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      outputCount, inData, outData, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_forward failed");
}

/**
 * @brief Maxout backward: routes each output gradient to the input element
 * that won the max in the forward pass.
 *
 * One thread handles one output element. idData holds the winning group
 * index of each output element, as written by maxoutFpCompute. The gradient
 * is added (not assigned) to inGrad, using the same per-sample input layout
 * as the forward kernel: (channel * groups + g) * featLen + feat.
 *
 * A plain `+=` is used. Distinct output elements of a sample have distinct
 * (channel, feat) pairs and so hit distinct inGrad slots, provided every
 * idData value lies in [0, groups).
 */
__global__ void maxoutBpCompute(size_t nthreads,
                                real* inGrad,
                                const real* outGrad,
                                const int* idData,
                                size_t size,
                                size_t featLen,
                                size_t groups) {
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= nthreads) {
    return;
  }

  const size_t sampleBase = (index / size) * size;
  const size_t offsetInSample = index % size;
  const size_t channel = offsetInSample / featLen;
  const size_t feat = offsetInSample % featLen;

  const size_t outPos = sampleBase + offsetInSample;
  const size_t winnerRow = channel * groups + idData[outPos];

  inGrad[sampleBase * groups + winnerRow * featLen + feat] += outGrad[outPos];
}

/**
 * @brief Host launcher for maxoutBpCompute.
 *
 * Requests size * batchSize threads (one per output-gradient element) in
 * 1024-thread blocks on STREAM_DEFAULT, then runs CHECK_SYNC. inGrad is
 * accumulated into.
 */
void hl_maxout_backward(real* inGrad,
                        const real* outGrad,
                        const int* idData,
                        size_t batchSize,
                        size_t size,
                        size_t featLen,
                        size_t groups) {
  const int kThreadsPerBlock = 1024;
  const int outputCount = size * batchSize;
  const int gridSize = (outputCount + kThreadsPerBlock - 1) / kThreadsPerBlock;

  maxoutBpCompute<<<gridSize, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
      outputCount, inGrad, outGrad, idData, size, featLen, groups);
  CHECK_SYNC("hl_maxout_backward failed");
}