/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "PadOp.h"
#include "hl_base.h"
namespace paddle {
/**
 * Scatter every input element into the interior of the padded output.
 *
 * Launch layout: 1-D grid of 1-D blocks providing at least `nthreads`
 * threads; the thread with flat index `idx` handles input element `idx`.
 * The index decomposition below treats `inputs` as a contiguous
 * [N, inC, inH, inW] tensor and `outputs` as [N, outC, outH, outW].
 * Only interior cells of `outputs` are written; padding cells are not
 * touched here (presumably the caller zero-fills them — TODO confirm).
 * Each thread writes a distinct output cell, so no atomics are needed.
 *
 * @param outputs         padded output buffer.
 * @param inputs          unpadded input buffer.
 * @param inC,inH,inW     input channel / height / width sizes.
 * @param padc,padh,padw  leading padding per dimension (interior offset).
 * @param outC,outH,outW  padded output channel / height / width sizes.
 * @param nthreads        number of input elements to copy.
 */
__global__ void KePad(real* outputs,
                      const real* inputs,
                      int inC,
                      int inH,
                      int inW,
                      int padc,
                      int padh,
                      int padw,
                      int outC,
                      int outH,
                      int outW,
                      int nthreads) {
  // Form the global index in 64 bits and bounds-check it BEFORE narrowing:
  // the grid is rounded up to whole blocks, so tail threads may have an
  // index above INT_MAX that would wrap negative if stored in an int.
  const long long gid =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (gid >= nthreads) return;
  const int idx = static_cast<int>(gid);  // safe: idx < nthreads <= INT_MAX

  const int w = idx % inW;
  const int h = (idx / inW) % inH;
  const int c = (idx / inW / inH) % inC;
  const int n = idx / inW / inH / inC;

  // The padded output can hold more than INT_MAX elements even when the
  // input does not, so its flat offset is computed in 64-bit arithmetic.
  const long long off =
      ((static_cast<long long>(n) * outC + c + padc) * outH + h + padh) *
          outW +
      padw + w;
  outputs[off] = inputs[idx];
}

/**
 * GPU specialization of Pad: copy a [num, inC, inH, inW] tensor into the
 * interior of a [num, outC, outH, outW] tensor, where each output extent is
 * the input extent plus the leading ([0]) and trailing ([1]) padding taken
 * from `pad`. The copy is done by KePad, which writes only interior cells.
 *
 * The kernel is launched on STREAM_DEFAULT and followed by the project's
 * CHECK_SYNC("Pad") post-launch check (semantics defined in hl_base.h).
 *
 * @param outputs  device buffer for the padded result.
 * @param inputs   device buffer holding the unpadded input.
 * @param num      batch size.
 * @param inC,inH,inW  input channel / height / width sizes.
 * @param pad      leading/trailing padding for channel, height, and width.
 */
template <>
void Pad<DEVICE_TYPE_GPU>(real* outputs,
                          const real* inputs,
                          const int num,
                          const int inC,
                          const int inH,
                          const int inW,
                          const PadConf& pad) {
  // KePad takes its element count as int. The product is evaluated in int
  // regardless, so storing it in size_t (as before) only obscured that.
  // NOTE(review): num * inC * inH * inW must fit in int; it is not checked.
  const int nth = num * inC * inH * inW;
  // A zero-sized grid is an invalid launch configuration; an empty tensor
  // simply has nothing to copy.
  if (nth == 0) return;

  const int blockSize = 1024;
  // Ceil-divide without forming nth + blockSize - 1, which can overflow int.
  const int gridSize = (nth - 1) / blockSize + 1;

  // Index [0] is the leading pad (also the interior offset), [1] trailing.
  const int cstart = pad.channel[0], cend = pad.channel[1];
  const int hstart = pad.height[0], hend = pad.height[1];
  const int wstart = pad.width[0], wend = pad.width[1];
  const int outC = inC + cstart + cend;
  const int outH = inH + hstart + hend;
  const int outW = inW + wstart + wend;

  KePad<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(outputs,
                                                    inputs,
                                                    inC,
                                                    inH,
                                                    inW,
                                                    cstart,
                                                    hstart,
                                                    wstart,
                                                    outC,
                                                    outH,
                                                    outW,
                                                    nth);
  CHECK_SYNC("Pad");
}
/**
 * Backward pass of KePad: for every input element, add the gradient of the
 * matching interior cell of the padded output into the input gradient
 * (inGrad[idx] += outGrad[off]). Gradients of padding cells are never read.
 *
 * Launch layout: 1-D grid of 1-D blocks providing at least `nthreads`
 * threads; the thread with flat index `idx` handles input element `idx`.
 * The index decomposition treats `inGrad` as a contiguous
 * [N, inC, inH, inW] tensor and `outGrad` as [N, outC, outH, outW].
 * Because the update accumulates, `inGrad` must already hold valid values
 * (e.g. zeros) on entry. Each thread updates a distinct element.
 *
 * @param inGrad          input-gradient buffer, accumulated into.
 * @param outGrad         gradient w.r.t. the padded output.
 * @param inC,inH,inW     input channel / height / width sizes.
 * @param padc,padh,padw  leading padding per dimension (interior offset).
 * @param outC,outH,outW  padded output channel / height / width sizes.
 * @param nthreads        number of input-gradient elements to update.
 */
__global__ void KePadDiff(real* inGrad,
                          const real* outGrad,
                          int inC,
                          int inH,
                          int inW,
                          int padc,
                          int padh,
                          int padw,
                          int outC,
                          int outH,
                          int outW,
                          int nthreads) {
  // Form the global index in 64 bits and bounds-check it BEFORE narrowing:
  // the grid is rounded up to whole blocks, so tail threads may have an
  // index above INT_MAX that would wrap negative if stored in an int.
  const long long gid =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (gid >= nthreads) return;
  const int idx = static_cast<int>(gid);  // safe: idx < nthreads <= INT_MAX

  const int w = idx % inW;
  const int h = (idx / inW) % inH;
  const int c = (idx / inW / inH) % inC;
  const int n = idx / inW / inH / inC;

  // The padded gradient can hold more than INT_MAX elements even when the
  // input does not, so its flat offset is computed in 64-bit arithmetic.
  const long long off =
      ((static_cast<long long>(n) * outC + c + padc) * outH + h + padh) *
          outW +
      padw + w;
  inGrad[idx] += outGrad[off];
}

/**
 * GPU specialization of PadGrad: accumulate the gradient of the padded
 * [num, outC, outH, outW] tensor back into the [num, inC, inH, inW] input
 * gradient, where each output extent is the input extent plus the leading
 * ([0]) and trailing ([1]) padding from `pad`. The work is done by
 * KePadDiff, which uses `+=`, so `inGrad` is accumulated into, not overwritten.
 *
 * The kernel is launched on STREAM_DEFAULT and followed by the project's
 * CHECK_SYNC("PadGrad") post-launch check (semantics defined in hl_base.h).
 *
 * @param inGrad   device buffer for the input gradient (accumulated into).
 * @param outGrad  device buffer holding the padded-output gradient.
 * @param num      batch size.
 * @param inC,inH,inW  input channel / height / width sizes.
 * @param pad      leading/trailing padding for channel, height, and width.
 */
template <>
void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
                              const real* outGrad,
                              const int num,
                              const int inC,
                              const int inH,
                              const int inW,
                              const PadConf& pad) {
  // KePadDiff takes its element count as int.
  // NOTE(review): num * inC * inH * inW must fit in int; it is not checked.
  const int nth = num * inC * inH * inW;
  // A zero-sized grid is an invalid launch configuration; an empty tensor
  // simply has no gradient to propagate.
  if (nth == 0) return;

  const int blockSize = 1024;
  // Ceil-divide without forming nth + blockSize - 1, which can overflow int.
  const int gridSize = (nth - 1) / blockSize + 1;

  // Index [0] is the leading pad (also the interior offset), [1] trailing.
  const int cstart = pad.channel[0], cend = pad.channel[1];
  const int hstart = pad.height[0], hend = pad.height[1];
  const int wstart = pad.width[0], wend = pad.width[1];
  const int outC = inC + cstart + cend;
  const int outH = inH + hstart + hend;
  const int outW = inW + wstart + wend;

  KePadDiff<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(inGrad,
                                                        outGrad,
                                                        inC,
                                                        inH,
                                                        inW,
                                                        cstart,
                                                        hstart,
                                                        wstart,
                                                        outC,
                                                        outH,
                                                        outW,
                                                        nth);
  CHECK_SYNC("PadGrad");
}

}  // namespace paddle