/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"

/**
 * Per-sequence, per-column maximum over the rows of `input`.
 *
 * Block blockIdx.x handles sequence s = blockIdx.x, whose rows are
 * [sequence[s], sequence[s + 1]). The threads of the block walk the `dim`
 * columns in steps of blockDim.x.
 *
 * For every column the largest value is written to output[s * dim + col] and
 * the row it came from to index[s * dim + col]. The comparison is strict, so
 * on ties the earliest row is kept. When no row compares greater than
 * -HL_FLOAT_MAX (for instance an empty sequence) the results are
 * -HL_FLOAT_MAX and -1.
 */
__global__ void KeMaxSequenceForward(real *input,
                                     const int *sequence,
                                     real* output,
                                     int *index,
                                     int numSequences,
                                     int dim) {
  const int seq = blockIdx.x;
  if (seq >= numSequences) return;

  const int rowBegin = sequence[seq];
  const int rowEnd = sequence[seq + 1];

  for (int col = threadIdx.x; col < dim; col += blockDim.x) {
    real best = -HL_FLOAT_MAX;
    int bestRow = -1;
    for (int row = rowBegin; row < rowEnd; ++row) {
      const real candidate = input[row * dim + col];
      if (best < candidate) {
        best = candidate;
        bestRow = row;
      }
    }
    const int outPos = seq * dim + col;
    output[outPos] = best;
    index[outPos] = bestRow;
  }
}

/**
 * Host launcher for KeMaxSequenceForward.
 *
 * All four pointers are checked with CHECK_NOTNULL, then the kernel is
 * launched on STREAM_DEFAULT with one block of 256 threads per sequence,
 * followed by CHECK_SYNC.
 */
void hl_max_sequence_forward(real* input,
                             const int* sequence,
                             real* output,
                             int *index,
                             int numSequences,
                             int dim) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(index);

  /* one block per sequence; 256 threads share that sequence's columns */
  const dim3 blockShape(256, 1);
  const dim3 gridShape(numSequences, 1);

  KeMaxSequenceForward<<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
      (input, sequence, output, index, numSequences, dim);
  CHECK_SYNC("hl_max_sequence_forward failed");
}

/**
 * Routes the gradient of the per-sequence maximum back to the rows that
 * produced each maximum.
 *
 * One thread per element of outputGrad (numSequences * dim elements, flat
 * index idx). index[idx] is the row recorded for that element and idx % dim
 * is its column; outputGrad[idx] is added to inputGrad[row * dim + column].
 *
 * Elements whose recorded row is negative are skipped. KeMaxSequenceForward
 * stores -1 when it found no maximum (e.g. an empty sequence); using that
 * value as a row would write in front of the inputGrad buffer.
 */
__global__ void KeMaxSequenceBackward(real *outputGrad,
                                      int *index,
                                      real* inputGrad,
                                      int numSequences,
                                      int dim) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= numSequences * dim) return;

  int insId = index[idx];
  if (insId < 0) return;  /* no argmax recorded: nothing to propagate */

  int colIdx = idx % dim;
  inputGrad[insId * dim + colIdx] += outputGrad[idx];
}

/**
 * Host launcher for KeMaxSequenceBackward.
 *
 * Checks the three pointers with CHECK_NOTNULL and launches enough
 * 128-thread blocks on STREAM_DEFAULT to give every one of the
 * numSequences * dim gradient elements its own thread, then runs CHECK_SYNC.
 */
void hl_max_sequence_backward(real* outputGrad,
                              int *index,
                              real* inputGrad,
                              int numSequences,
                              int dim) {
  CHECK_NOTNULL(outputGrad);
  CHECK_NOTNULL(index);
  CHECK_NOTNULL(inputGrad);

  const int kThreadsPerBlock = 128;
  /* ceil(numSequences * dim / kThreadsPerBlock) */
  unsigned int blocks =
      (numSequences * dim + kThreadsPerBlock - 1) / kThreadsPerBlock;
  const dim3 blockShape(kThreadsPerBlock, 1);
  const dim3 gridShape(blocks, 1);

  KeMaxSequenceBackward<<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
      (outputGrad, index, inputGrad, numSequences, dim);
  CHECK_SYNC("hl_max_sequence_backward failed");
}

/**
 * Row-wise accumulate between `output` and `table`, selected by `ids`.
 *
 * Sample s pairs row s of `output` with row ids[s] of `table`; rows are `dim`
 * elements wide. Samples whose id lies outside [0, tableSize) are skipped.
 *   AddRow == false: output row += table row
 *   AddRow == true : table row  += output row, through
 *                    paddle::paddleAtomicAdd (several samples may name the
 *                    same table row)
 *
 * The template parameters blockDimX, blockDimY and gridDimX are used as loop
 * strides; presumably they have to match the block and grid dimensions of
 * the launch -- no launch site is visible in this part of the file.
 */
template<int blockDimX, int blockDimY, int gridDimX, bool AddRow>
__global__ void KeMatrixAddRows(real* output,
                                real* table,
                                int* ids,
                                int numSamples,
                                int tableSize,
                                int dim) {
  int idx = threadIdx.x;
  int idy = threadIdx.y;
  int sampleId = blockIdx.x + idy * gridDimX;

  while (sampleId < numSamples) {
    int tableId = ids[sampleId];
    if ((0 <= tableId) && (tableId < tableSize)) {
      real *outputData = output + sampleId * dim;
      real *tableData = table + tableId * dim;
      for (int i = idx; i < dim; i += blockDimX) {
        if (!AddRow) {
          outputData[i] += tableData[i];
        } else {
          paddle::paddleAtomicAdd(&tableData[i], outputData[i]);
        }
      }
    }
    /* advance to the next sample handled by this (blockIdx.x, threadIdx.y) */
    sampleId += blockDimY*gridDimX;
  }
}

/**
 * Row copy / accumulate between `batch` and `sequence`, driven by batchIndex.
 *
 * Row `id` of `batch` is paired with row batchIndex[id] of `sequence`; rows
 * are seqWidth elements wide and `id` covers [0, batchCount).
 *   seq2batch: data flows sequence -> batch, otherwise batch -> sequence.
 *   isAdd    : the destination is accumulated into (+=) instead of being
 *              overwritten.
 *
 * blockDimX, blockDimY and gridDimX are compile-time loop strides. The
 * launchers in this file instantiate <128, 8, 8, ...> together with
 * dim3 threads(128, 8) and dim3 grid(8, 1).
 */
template<int blockDimX, int blockDimY, int gridDimX, bool seq2batch, bool isAdd>
__global__
void KeSequence2Batch(real *batch,
                      real *sequence,
                      const int *batchIndex,
                      int seqWidth,
                      int batchCount) {
  int idx = threadIdx.x;
  int idy = threadIdx.y;
  int id = blockIdx.x + idy * gridDimX;
  while (id < batchCount) {
    int seqId = batchIndex[id];
    real* batchData = batch + id*seqWidth;
    real* seqData = sequence + seqId*seqWidth;
    for (int i = idx; i < seqWidth; i += blockDimX) {
      if (seq2batch) {
        if (isAdd) {
          batchData[i] += seqData[i];
        } else {
          batchData[i] = seqData[i];
        }
      } else {
        if (isAdd) {
          seqData[i] += batchData[i];
        } else {
          seqData[i] = batchData[i];
        }
      }
    }
    /* next batch row handled by this (blockIdx.x, threadIdx.y) pair */
    id += blockDimY*gridDimX;
  }
}

/**
 * Copies rows between `batch` and `sequence` (overwriting the destination).
 *
 * Launches KeSequence2Batch with isAdd = 0 on STREAM_DEFAULT using
 * dim3 threads(128, 8) and dim3 grid(8, 1), then runs CHECK_SYNC.
 *
 * @param batch       rows addressed by position id in [0, batchCount)
 * @param sequence    rows addressed by batchIndex[id]
 * @param batchIndex  per-batch-row index into `sequence`
 * @param seqWidth    number of elements per row
 * @param batchCount  number of rows to process
 * @param seq2batch   true: sequence -> batch, false: batch -> sequence
 */
void hl_sequence2batch_copy(real *batch,
                            real *sequence,
                            const int *batchIndex,
                            int seqWidth,
                            int batchCount,
                            bool seq2batch) {
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(batchIndex);

  /* must agree with the <128, 8, 8> template arguments below */
  dim3 threads(128, 8);
  dim3 grid(8, 1);
  if (seq2batch) {
    KeSequence2Batch<128, 8, 8, 1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
      (batch, sequence, batchIndex, seqWidth, batchCount);
  } else {
    KeSequence2Batch<128, 8, 8, 0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>
      (batch, sequence, batchIndex, seqWidth, batchCount);
  }
  CHECK_SYNC("hl_sequence2batch_copy failed");
}

/**
 * Accumulating counterpart of the row copy: the destination rows are added
 * to (+=) rather than overwritten.
 *
 * Launches KeSequence2Batch with isAdd = 1 on STREAM_DEFAULT using a
 * (128, 8) thread block and an (8, 1) grid, then runs CHECK_SYNC.
 * seq2batch selects the direction: true is sequence -> batch, false is
 * batch -> sequence.
 */
void hl_sequence2batch_add(real *batch,
                           real *sequence,
                           int *batchIndex,
                           int seqWidth,
                           int batchCount,
                           bool seq2batch) {
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(batchIndex);

  /* launch shape mirrors the <128, 8, 8> template arguments */
  const dim3 blockShape(128, 8);
  const dim3 gridShape(8, 1);

  if (!seq2batch) {
    /* batch -> sequence */
    KeSequence2Batch<128, 8, 8, 0, 1>
      <<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
      (batch, sequence, batchIndex, seqWidth, batchCount);
  } else {
    /* sequence -> batch */
    KeSequence2Batch<128, 8, 8, 1, 1>
      <<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
      (batch, sequence, batchIndex, seqWidth, batchCount);
  }
  CHECK_SYNC("hl_sequence2batch_add failed");
}

/**
 * Copies between a packed `sequence` buffer and a padded `batch` buffer.
 *
 * blockIdx.y selects the sequence b, blockIdx.x * blockDim.y + threadIdx.y
 * selects the step t inside it, and threadIdx.x strides over the
 * sequenceWidth elements of one row.
 *
 *   sequence row: sequenceStartPositions[b] + t
 *   batch row   : t * numSequences + b
 *
 * seq2batch == true : batch row = scale * sequence row for t < length, and
 *                     batch row = 0 for length <= t < maxSequenceLength.
 * seq2batch == false: sequence row = scale * batch row for t < length;
 *                     padding rows are not read.
 * scale is 1 / length when normByTimes is set and 1 otherwise.
 *
 * Row base offsets are kept in size_t so that offsets beyond the range of
 * int are not truncated.
 */
template<bool normByTimes, bool seq2batch>
__global__
void KeSequence2BatchPadding(real* batch,
                             real* sequence,
                             const int* sequenceStartPositions,
                             const size_t sequenceWidth,
                             const size_t maxSequenceLength,
                             const size_t numSequences) {
  int batchIdx = blockIdx.y;
  int sequenceStart = sequenceStartPositions[batchIdx];
  int sequenceLength = sequenceStartPositions[batchIdx + 1] - sequenceStart;

  int sequenceIdx = blockIdx.x * blockDim.y + threadIdx.y;
  size_t batchBaseIdx =
      (sequenceIdx * numSequences + batchIdx) * sequenceWidth;
  size_t sequenceBaseIdx = (sequenceStart + sequenceIdx) * sequenceWidth;

  real scale = normByTimes ? (1.0f / (real)sequenceLength) : 1.0f;

  if (sequenceIdx < sequenceLength) {
    if (seq2batch) {
      /* sequence -> batch */
      for (size_t i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
        batch[batchBaseIdx + i] = scale * sequence[sequenceBaseIdx + i];
      }
    } else {
      /* batch -> sequence */
      for (size_t i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
        sequence[sequenceBaseIdx + i] = scale * batch[batchBaseIdx + i];
      }
    }
  } else if (sequenceIdx < maxSequenceLength) {
    if (seq2batch) {
      /* sequence -> batch: zero the padding rows */
      for (size_t i = threadIdx.x; i < sequenceWidth; i += blockDim.x) {
        batch[batchBaseIdx + i] = 0;
      }
    }
  }
}

/**
 * Host entry point for copying between a packed `sequence` buffer and a
 * padded `batch` buffer (see KeSequence2BatchPadding for the row mapping).
 *
 * @param batch                   padded buffer
 * @param sequence                packed buffer
 * @param sequenceStartPositions  start row of each sequence; entry b + 1 is
 *                                read as the end of sequence b
 * @param sequenceWidth           elements per row
 * @param maxSequenceLength       number of steps covered per sequence
 * @param numSequences            number of sequences
 * @param normByTimes             scale every copied row by 1 / length
 * @param seq2batch               true: sequence -> batch,
 *                                false: batch -> sequence
 *
 * If any of the three sizes is zero there is nothing to copy and the
 * function returns right after the pointer checks. With a single sequence
 * and no scaling the buffers are copied as one flat range of
 * maxSequenceLength * sequenceWidth elements via hl_memcpy_device2device.
 */
void hl_sequence2batch_copy_padding(real* batch,
                                    real* sequence,
                                    const int* sequenceStartPositions,
                                    const size_t sequenceWidth,
                                    const size_t maxSequenceLength,
                                    const size_t numSequences,
                                    bool normByTimes,
                                    bool seq2batch) {
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(sequenceStartPositions);

  /* Nothing to move. Returning here also keeps blockDimX below from becoming
     0 (sequenceWidth == 0), which would divide by zero, and keeps the grid
     from getting a zero dimension. */
  if (sequenceWidth == 0 || maxSequenceLength == 0 || numSequences == 0) {
    return;
  }

  if (!normByTimes && numSequences == 1) {
    size_t elementCount = maxSequenceLength * sequenceWidth;
    if (seq2batch) {
      /* sequence -> batch */
      hl_memcpy_device2device(batch, sequence, sizeof(real) * elementCount);
    } else {
      /* batch -> sequence */
      hl_memcpy_device2device(sequence, batch, sizeof(real) * elementCount);
    }
    return;
  }

  const int CUDA_BLOCK_SIZE = 512;

  /* At least use 32 threads to copy sequenceWidth elements,
     and at least 8 elements for each thread. */
  int blockDimX = ((((sequenceWidth + 7) >> 3) + 31) >> 5) << 5;
  blockDimX = (blockDimX < CUDA_BLOCK_SIZE) ? blockDimX : CUDA_BLOCK_SIZE;

  int blockDimY = CUDA_BLOCK_SIZE / blockDimX;
  dim3 threads(blockDimX, blockDimY);

  /* grid.x covers the steps (blockDimY per block), grid.y the sequences */
  int gridDimX = (maxSequenceLength + blockDimY - 1) / blockDimY;
  int gridDimY = numSequences;
  dim3 grid(gridDimX, gridDimY);

  if (seq2batch) {
    /* sequence -> batch */
    if (normByTimes) {
      KeSequence2BatchPadding<1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    } else {
      KeSequence2BatchPadding<0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    }
  } else {
    /* batch -> sequence */
    if (normByTimes) {
      KeSequence2BatchPadding<1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    } else {
      KeSequence2BatchPadding<0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    }
  }

  CHECK_SYNC("hl_sequence2batch_copy_padding failed");
}

/** Reciprocal square root, single-precision overload: forwards to rsqrtf. */
__device__ inline float my_rsqrt(float x) {
  return rsqrtf(x);
}

/** Reciprocal square root, double-precision overload: forwards to rsqrt. */
__device__ inline double my_rsqrt(double x) {
  const double reciprocalRoot = rsqrt(x);
  return reciprocalRoot;
}

/**
 * Reduces the rows of each sequence in `src` to one row and accumulates the
 * result into `dst`.
 *
 * One thread per element of dst (height * width elements, flat index gid
 * with row = gid / width and col = gid % width). The source rows of output
 * row `row` are [starts[row], starts[row + 1]).
 *
 *   mode 1    : sum of the rows
 *   mode 0    : sum / sequence length
 *   otherwise : sum * rsqrt(sequence length)
 *
 * The result is added to dst[gid] (+=), not assigned. Sequences of length 0
 * leave dst untouched.
 */
__global__ void KeSequenceAvgForward(real* dst,
                                     real* src,
                                     const int* starts,
                                     int height,
                                     int width,
                                     const int mode) {
  int gid = blockIdx.x * blockDim.x + threadIdx.x;
  int row = gid / width;
  int col = gid % width;

  if (gid < height * width) {
    int start = starts[row];
    int end = starts[row + 1];
    int seqLength = end - start;
    if (seqLength == 0) return;
    real sum = 0.0;
    for (int i = start; i < end; i++) {
      sum += src[i * width + col];
    }
    sum = mode == 1 ? sum :
        (mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength));
    dst[gid] += sum;
  }
}

/**
 * Host launcher for KeSequenceAvgForward.
 *
 * Checks the pointers with CHECK_NOTNULL, requires mode to be 0, 1 or 2, and
 * launches one thread per element of dst (height * width) in blocks of 512
 * threads on STREAM_DEFAULT, followed by CHECK_SYNC.
 */
void hl_sequence_avg_forward(real* dst,
                             real* src,
                             const int* starts,
                             int height,
                             int width,
                             const int mode) {
  CHECK_NOTNULL(dst);
  CHECK_NOTNULL(src);
  CHECK_NOTNULL(starts);

  int block = 512;
  /* derive the grid from `block` so the two cannot drift apart */
  int grid = DIVUP(width * height, block);

  CHECK(mode == 0 || mode == 1 || mode == 2)
    << "mode error in hl_sequence_avg_forward!";

  KeSequenceAvgForward<<< grid, block, 0, STREAM_DEFAULT >>>
           (dst, src, starts, height, width, mode);
  CHECK_SYNC("hl_sequence_avg_forward failed");
}

/**
 * Spreads each element of `src` over all rows of its sequence in `dst`.
 *
 * One thread per element of src (height * width elements, flat index gid
 * with row = gid / width and col = gid % width). The destination rows of
 * source row `row` are [starts[row], starts[row + 1]).
 *
 * The value added to every destination row (+=) is
 *   mode 1    : src[gid]
 *   mode 0    : src[gid] / sequence length
 *   otherwise : src[gid] * rsqrt(sequence length)
 * Sequences of length 0 are skipped.
 */
__global__ void KeSequenceAvgBackward(real* dst,
                                      real* src,
                                      const int* starts,
                                      int height,
                                      int width,
                                      const int mode) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int row = gid / width;
  const int col = gid % width;
  if (gid >= height * width) return;

  const int firstRow = starts[row];
  const int lastRow = starts[row + 1];
  const int seqLength = lastRow - firstRow;
  if (seqLength == 0) return;

  real grad = src[gid];
  if (mode == 0) {
    grad = grad / seqLength;
  } else if (mode != 1) {
    grad = grad * my_rsqrt((real)seqLength);
  }

  for (int r = firstRow; r < lastRow; ++r) {
    dst[r * width + col] += grad;
  }
}

/**
 * Host launcher for KeSequenceAvgBackward.
 *
 * Checks the pointers with CHECK_NOTNULL, requires mode to be 0, 1 or 2, and
 * launches one thread per element of src (height * width) in blocks of 512
 * threads on STREAM_DEFAULT, followed by CHECK_SYNC.
 */
void hl_sequence_avg_backward(real* dst,
                              real* src,
                              const int* starts,
                              int height,
                              int width,
                              const int mode) {
  CHECK_NOTNULL(dst);
  CHECK_NOTNULL(src);
  CHECK_NOTNULL(starts);

  const int threadsPerBlock = 512;
  const int numBlocks = DIVUP(width * height, 512);

  CHECK(mode == 0 || mode == 1 || mode == 2)
    << "mode error in hl_sequence_avg_backward!";

  KeSequenceAvgBackward<<< numBlocks, threadsPerBlock, 0, STREAM_DEFAULT >>>
           (dst, src, starts, height, width, mode);
  CHECK_SYNC("hl_sequence_avg_backward failed");
}