// RowConvOpGpu.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "RowConvOp.h"

namespace paddle {

/**
 * Forward row convolution with the filter column tile staged in shared memory.
 *
 * For every sequence [starts[s], starts[s + 1]) and every row r inside it:
 *   y[r][c] += sum_{k < context, r + k < end} w[k][c] * x[r + k][c]
 *
 * Launch layout: blockIdx.x / threadIdx.x select the column (blockDim.x is
 * expected to equal BLOCK_W); threads along y stride over the rows of each
 * sequence. The caller must guarantee context <= BLOCK_H, otherwise the shared
 * filter tile `sw` overflows.
 */
template<int BLOCK_H, int BLOCK_W>
__global__ void KeRowConv(real* y, const real* x,  const real* w,
    const int* starts, const int height, const int width,
    const int numSeq, const int context) {
  const int lane = threadIdx.x;
  const int firstRow = threadIdx.y;
  const int rowStride = blockDim.y;
  const int col = static_cast<int>(blockIdx.x * blockDim.x) + lane;
  const bool inCol = col < width;

  // Stage the filter rows for this column tile; columns past `width` are
  // zero-filled so the tile can be read unconditionally below.
  __shared__ real sw[BLOCK_H][BLOCK_W];
  for (int k = firstRow; k < context; k += rowStride) {
    sw[k][lane] = inCol ? w[k * width + col] : real(0);
  }
  __syncthreads();

  for (int s = 0; s < numSeq; ++s) {
    const int seqBegin = starts[s];
    const int seqEnd = starts[s + 1];
    for (int row = seqBegin + firstRow; row < seqEnd; row += rowStride) {
      real acc = 0;
      // Look ahead at most `context` rows, never past the end of the sequence.
      for (int k = 0; k < context && row + k < seqEnd; ++k) {
        const real xk = inCol ? x[(row + k) * width + col] : real(0);
        acc += sw[k][lane] * xk;
      }
      if (inCol) {
        y[row * width + col] += acc;
      }
    }
  }
}

/**
 * Forward row convolution that reads the filter straight from global memory.
 * Used by RowConv<DEVICE_TYPE_GPU> when the context is too long for the
 * shared tile of KeRowConv.
 *
 *   y[r][c] += sum_{k < context, r + k < end} w[k][c] * x[r + k][c]
 *
 * blockIdx.x / threadIdx.x select the column; threads along y stride over the
 * rows of each sequence [starts[s], starts[s + 1]).
 */
__global__ void KeRowConv2(real* y, const real* x,  const real* w,
    const int* starts, const int height, const int width,
    const int numSeq, const int context) {
  const int firstRow = threadIdx.y;
  const int rowStride = blockDim.y;
  const int col = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  const bool inCol = col < width;

  for (int s = 0; s < numSeq; ++s) {
    const int seqBegin = starts[s];
    const int seqEnd = starts[s + 1];
    for (int row = seqBegin + firstRow; row < seqEnd; row += rowStride) {
      real acc = 0;
      for (int k = 0; k < context && row + k < seqEnd; ++k) {
        const real wk = inCol ? w[k * width + col] : real(0);
        const real xk = inCol ? x[(row + k) * width + col] : real(0);
        acc += wk * xk;
      }
      if (inCol) {
        y[row * width + col] += acc;
      }
    }
  }
}



/**
 * GPU row convolution forward pass; results are accumulated into `out`.
 *
 * @param out    output matrix, same shape as `in`; written with `+=`.
 * @param in     input matrix (height x width).
 * @param filter filter matrix; its height is the context length. The kernels
 *               index it with in.getWidth() as the row stride, so it is
 *               assumed to be contextLength x width.
 * @param seq    sequence start offsets; seq.getSize() - 1 sequences.
 */
template <>
void RowConv<DEVICE_TYPE_GPU>(GpuMatrix& out,
                              const GpuMatrix& in,
                              const GpuMatrix& filter,
                              const GpuIVector& seq) {
  const size_t numSeq = seq.getSize() - 1;
  const size_t contextLength = filter.getHeight();
  const size_t height = in.getHeight();
  const size_t width = in.getWidth();

  // One 32x32 thread block per 32-column tile; the y threads cover the rows.
  const dim3 dimBlock(32, 32);
  const dim3 dimGrid(DIVUP(width, dimBlock.x), 1);

  if (contextLength > 32) {
    // Too many filter rows for the 32-row shared tile: use the global-memory
    // variant.
    KeRowConv2<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>(
        out.getData(), in.getData(), filter.getData(), seq.getData(),
        height, width, numSeq, contextLength);
  } else {
    KeRowConv<32, 32><<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>(
        out.getData(), in.getData(), filter.getData(), seq.getData(),
        height, width, numSeq, contextLength);
  }
  CHECK_SYNC("RowConv");
}


/**
 * Filter gradient of row convolution for short contexts (context <= CONTEXT).
 *
 * For each column c handled by this block it accumulates
 *   dw[t][c] += sum over sequences [start, end),
 *               sum_{r in [start, end), r - t >= start} x[r][c] * dy[r - t][c]
 *
 * Rows are processed in tiles of BLOCK_H, transposed through shared memory so
 * that each warp (fixed threadIdx.y) owns one column and its 32 lanes hold 32
 * consecutive rows, which are summed with a warp shuffle.
 *
 * Launch requirements: blockDim == (BLOCK_W, BLOCK_H) with
 * BLOCK_W == BLOCK_H == 32 (warp-sized lanes and a square transpose), and
 * context <= CONTEXT.
 */
template<int BLOCK_H, int BLOCK_W, int CONTEXT>
__global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
    const int* starts, const int height, const int width, const int numSeq,
    const int context) {

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int blky = blockDim.y;
  const int gidx = blockIdx.x * blockDim.x;

  // Transposed tiles: sh_x[c][r] is x at column gidx + c, tile row r.
  // sh_dy keeps context - 1 extra leading rows (a halo of the dy rows just
  // before the tile) so dy[r - t] is available for every shift t < context.
  __shared__ real sh_x[BLOCK_W][BLOCK_H];
  __shared__ real sh_dy[BLOCK_W][BLOCK_H + CONTEXT - 1];
  // Per-block partial filter gradient, indexed [shift][column].
  __shared__ real sh_dw[CONTEXT][BLOCK_W];

  // Strided so every shift t < context is cleared even if context > blky.
  for (int t = tidy; t < context; t += blky) {
    sh_dw[t][tidx] = 0;
  }
  __syncthreads();

  for (int i = 0; i < numSeq; ++i) {
    const int start = starts[i];
    const int end = starts[i + 1];
    const int steps = end - start;
    // Round up to whole tiles so all threads run the same number of
    // iterations and reach every __syncthreads() below.
    const int size = ((steps + BLOCK_H - 1) / BLOCK_H) * BLOCK_H;
    for (int j = tidy; j < size; j += BLOCK_H) {
      int xoff = gidx + tidx;
      int yoff = start + j;

      // Transpose one tile into shared memory, zero-padding rows past the end
      // of the sequence and columns past `width`.
      sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
          x[yoff * width + xoff] : real(0);
      sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ?
          dy[yoff * width + xoff] : real(0);
      __syncthreads();
      // Load the halo: the context - 1 dy rows preceding this tile. Rows
      // before the sequence start contribute zero.
      if (tidy < (context - 1)) {
        yoff = yoff - context + 1;
        sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ?
            dy[yoff * width + xoff] : real(0);
      }
      __syncthreads();

      for (int t = 0; t < context; t++) {
        // Lane tidx holds tile row r = tidx of column tidy: x[r] * dy[r - t].
        real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx + context - 1 - t];
        __syncthreads();
        // Sum the 32 rows of this warp (warp size and blockDim.x are 32);
        // lane 0 ends up with the total. All loops above have block-uniform
        // trip counts, so the full warp is active here.
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
          // Mask-less __shfl_down is deprecated in CUDA 9 and unavailable
          // for sm_70+.
          val += __shfl_down_sync(0xffffffffu, val, offset);
#else
          val += __shfl_down(val, offset);
#endif
        }
        __syncthreads();
        if (tidx == 0) {
          sh_dw[t][tidy] += val;
        }
        __syncthreads();
      }
    }
  }

  // Flush this block's partial gradient for its column tile.
  for (int t = tidy; (t < context) && ((gidx + tidx) < width); t += blky) {
    dw[t * width + gidx + tidx] += sh_dw[t][tidx];
  }
}

/**
 * Filter gradient of row convolution for long contexts (no context-sized
 * shared buffer). For each column c handled by this block it accumulates
 *   dw[t][c] += sum over sequences [start, end),
 *               sum_{r in [start, end), r - t >= start} x[r][c] * dy[r - t][c]
 * writing directly to global memory once per (tile, shift).
 *
 * Launch requirements: blockDim == (BLOCK_W, BLOCK_H) with
 * BLOCK_W == BLOCK_H == 32 (warp-sized lanes and a square transpose).
 */
template<int BLOCK_H, int BLOCK_W>
__global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
    const int* starts, const int height, const int width, const int numSeq,
    const int context) {

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int gidx = blockIdx.x * blockDim.x;

  // Transposed tiles: sh_x[c][r] / sh_dy[c][r] hold column gidx + c, row r.
  __shared__ real sh_x[BLOCK_H][BLOCK_W];
  __shared__ real sh_dy[BLOCK_H][BLOCK_W];

  for (int i = 0; i < numSeq; ++i) {
    const int start = starts[i];
    const int end = starts[i + 1];
    const int steps = end - start;

    // Round up to whole tiles so all threads run the same number of
    // iterations and reach every __syncthreads() below.
    const int size = ((steps + BLOCK_H - 1) / BLOCK_H) * BLOCK_H;
    for (int j = tidy; j < size; j += BLOCK_H) {
      int xoff = gidx + tidx;
      int yoff = start + j;

      // Transpose the x tile, zero-padding out-of-range rows/columns.
      sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
          x[yoff * width + xoff] : real(0);
      __syncthreads();

      for (int t = 0; t < context; t++) {
        // dy shifted back by t rows; rows outside [start, end) contribute 0.
        sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start &&
            yoff - t < end) ? dy[(yoff - t) * width + xoff] : real(0);
        __syncthreads();

        real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
        __syncthreads();
        // Sum the 32 rows of this warp (warp size and blockDim.x are 32);
        // lane 0 ends up with the total. All loops above have block-uniform
        // trip counts, so the full warp is active here.
#pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
          // Mask-less __shfl_down is deprecated in CUDA 9 and unavailable
          // for sm_70+.
          val += __shfl_down_sync(0xffffffffu, val, offset);
#else
          val += __shfl_down(val, offset);
#endif
        }
        __syncthreads();

        if (tidx == 0 && (gidx + tidy) < width) {
          dw[t * width + gidx + tidy] += val;
        }
      }
    }
  }
}

/**
 * Input gradient of row convolution with the filter staged in shared memory.
 *
 * For every row r of each sequence [start, end):
 *   dx[r][c] += sum_{k < context, r - k >= start} w[k][c] * dy[r - k][c]
 *
 * Requirements: blockDim.x == BLOCK_W, and context <= BLOCK_H. BLOCK_H is the
 * number of filter rows the shared tile `sw` can hold; a larger context
 * overflows it.
 */
template<int BLOCK_H, int BLOCK_W>
__global__ void KeRowConvBwData(real* dx, const real* w, const real* dy,
    const int* starts, const int height, const int width, const int numSeq,
    const int context) {
  const int lane = threadIdx.x;
  const int firstRow = threadIdx.y;
  const int rowStride = blockDim.y;
  const int col = static_cast<int>(blockIdx.x * blockDim.x) + lane;
  const bool inCol = col < width;

  // Stage the filter rows for this column tile (zero for columns >= width).
  __shared__ real sw[BLOCK_H][BLOCK_W];
  for (int k = firstRow; k < context; k += rowStride) {
    sw[k][lane] = inCol ? w[k * width + col] : real(0);
  }
  __syncthreads();

  for (int s = 0; s < numSeq; ++s) {
    const int seqBegin = starts[s];
    const int seqEnd = starts[s + 1];
    for (int row = seqBegin + firstRow; row < seqEnd; row += rowStride) {
      real acc = 0;
      // Look back at most `context` rows, never before the sequence start.
      for (int k = 0; k < context && row - k >= seqBegin; ++k) {
        const real g = inCol ? dy[(row - k) * width + col] : real(0);
        acc += sw[k][lane] * g;
      }
      if (inCol) {
        dx[row * width + col] += acc;
      }
    }
  }
}

/**
 * Input gradient of row convolution reading the filter from global memory;
 * the fallback for contexts too long for KeRowConvBwData's shared tile.
 *
 *   dx[r][c] += sum_{k < context, r - k >= start} w[k][c] * dy[r - k][c]
 *
 * blockIdx.x / threadIdx.x select the column; threads along y stride over the
 * rows of each sequence [starts[s], starts[s + 1]).
 */
__global__ void KeRowConvBwData2(real* dx, const real* w, const real* dy,
    const int* starts, const int height, const int width, const int numSeq,
    const int context) {
  const int firstRow = threadIdx.y;
  const int rowStride = blockDim.y;
  const int col = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
  const bool inCol = col < width;

  for (int s = 0; s < numSeq; ++s) {
    const int seqBegin = starts[s];
    const int seqEnd = starts[s + 1];
    for (int row = seqBegin + firstRow; row < seqEnd; row += rowStride) {
      real acc = 0;
      for (int k = 0; k < context && row - k >= seqBegin; ++k) {
        const real g = inCol ? dy[(row - k) * width + col] : real(0);
        const real wk = inCol ? w[k * width + col] : real(0);
        acc += wk * g;
      }
      if (inCol) {
        dx[row * width + col] += acc;
      }
    }
  }
}


/**
 * GPU row convolution backward pass.
 *
 * Accumulates (`+=`) the filter gradient into `filterG` and the input gradient
 * into `inG`. Each is computed only when the corresponding matrix tests true.
 *
 * @param outG    gradient w.r.t. the forward output (height x width).
 * @param in      forward input (height x width).
 * @param filter  forward filter; its height is the context length.
 * @param inG     input gradient, accumulated in place.
 * @param filterG filter gradient, accumulated in place.
 * @param seq     sequence start offsets; seq.getSize() - 1 sequences.
 */
template <>
void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
                              const GpuMatrix& in,
                              const GpuMatrix& filter,
                              GpuMatrix& inG,
                              GpuMatrix& filterG,
                              const GpuIVector& seq) {
  const size_t numSeq = seq.getSize() - 1;
  const size_t contextLength = filter.getHeight();
  const size_t height = in.getHeight();
  const size_t width = in.getWidth();

  const real* dy = outG.getData();
  const real* x = in.getData();
  const real* w = filter.getData();
  const int* starts = seq.getData();

  if (filterG) {
    dim3 dimBlock(32, 32);
    dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
    real* dw = filterG.getData();
    if (contextLength <= 32) {
      KeRowConvBwWeight<32, 32, 32>
        <<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
        (dw, x, dy, starts, height, width, numSeq, contextLength);
    } else {
      KeRowConvBwWeight2<32, 32>
        <<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
        (dw, x, dy, starts, height, width, numSeq, contextLength);
    }
  }

  if (inG) {
    real* dx = inG.getData();
    dim3 dimBlock2(32, 32);
    dim3 dimGrid2(DIVUP(width, dimBlock2.x), 1);
    if (contextLength <= 64) {
      // KeRowConvBwData<BLOCK_H, BLOCK_W> stores the filter as
      // sw[BLOCK_H][BLOCK_W] indexed [filterRow][threadIdx.x]. BLOCK_H must
      // therefore cover the context (up to 64 rows), and BLOCK_W must equal
      // blockDim.x (32). <32, 64> would hold only 32 rows and overflow shared
      // memory for contexts of 33..64.
      KeRowConvBwData<64, 32>
        <<<dimGrid2, dimBlock2, 0, STREAM_DEFAULT>>>
        (dx, w, dy, starts, height, width, numSeq, contextLength);
    } else {
      KeRowConvBwData2
        <<<dimGrid2, dimBlock2, 0, STREAM_DEFAULT>>>
        (dx, w, dy, starts, height, width, numSeq, contextLength);
    }
  }

  CHECK_SYNC("RowConvGrad");
}

}  // namespace paddle