/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "hl_device_functions.cuh"
#include "paddle/utils/Logging.h"

/*
 * Per-sequence, per-column max pooling.
 *
 * Launch layout: one block per sequence (blockIdx.x); the threads of the
 * block stride over the `dim` columns. Rows [sequence[s], sequence[s+1]) of
 * the row-major `input` (row length `dim`) belong to sequence s.
 *
 * For every column the maximum is written to output[s * dim + col] and the
 * absolute input row that produced it to index[s * dim + col]. Ties keep the
 * earliest row (strict `<`). If no element exceeds -HL_FLOAT_MAX (e.g. an
 * empty sequence) the output is -HL_FLOAT_MAX and the index is -1.
 */
__global__ void KeMaxSequenceForward(real *input,
                                     const int *sequence,
                                     real* output,
                                     int *index,
                                     int numSequences,
                                     int dim) {
  const int seq = blockIdx.x;
  if (seq >= numSequences) {
    return;
  }

  const int rowBegin = sequence[seq];
  const int rowEnd = sequence[seq + 1];  // exclusive
  real* outRow = output + seq * dim;
  int* argRow = index + seq * dim;

  for (int col = threadIdx.x; col < dim; col += blockDim.x) {
    real best = -HL_FLOAT_MAX;
    int bestRow = -1;
    for (int row = rowBegin; row < rowEnd; ++row) {
      const real candidate = input[row * dim + col];
      if (best < candidate) {
        best = candidate;
        bestRow = row;
      }
    }
    outRow[col] = best;
    argRow[col] = bestRow;
  }
}

/*
 * Host launcher for KeMaxSequenceForward: one block of 256 threads per
 * sequence, launched on STREAM_DEFAULT, followed by CHECK_SYNC.
 *
 * Nothing is launched when numSequences <= 0. A grid with zero blocks is an
 * invalid launch configuration and would leave a pending CUDA error to be
 * reported by some unrelated later call.
 */
void hl_max_sequence_forward(real* input,
                             const int* sequence,
                             real* output,
                             int *index,
                             int numSequences,
                             int dim) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(index);

  if (numSequences <= 0) {
    return;
  }

  dim3 threads(256, 1);
  dim3 grid(numSequences, 1);
  KeMaxSequenceForward<<< grid, threads, 0, STREAM_DEFAULT >>>
      (input, sequence, output, index, numSequences, dim);
  CHECK_SYNC("hl_max_sequence_forward failed");
}

/*
 * Scatters the max-pooling gradient back to the winning input rows.
 *
 * One thread per output element (numSequences * dim of them, 1-D grid).
 * Thread idx adds outputGrad[idx] to inputGrad[index[idx] * dim + idx % dim].
 *
 * KeMaxSequenceForward stores -1 in `index` when no row was selected (empty
 * sequence, or no value above -HL_FLOAT_MAX). Such entries are skipped;
 * previously they wrote before the start of inputGrad.
 *
 * NOTE(review): the non-atomic += assumes each (row, column) target is hit
 * by at most one output element, which holds when `index` comes from the
 * forward kernel (indices stay within their own sequence).
 */
__global__ void KeMaxSequenceBackward(real *outputGrad,
                                      int *index,
                                      real* inputGrad,
                                      int numSequences,
                                      int dim) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= numSequences * dim) {
    return;
  }

  int insId = index[idx];
  if (insId < 0) {
    return;  // no winning row for this element: nothing to propagate
  }

  int colIdx = idx % dim;
  inputGrad[insId * dim + colIdx] += outputGrad[idx];
}

/*
 * Host launcher for KeMaxSequenceBackward: one thread per output element
 * (numSequences * dim), in blocks of 128, on STREAM_DEFAULT.
 *
 * Nothing is launched when there are no output elements, because a
 * zero-block grid is an invalid launch configuration.
 */
void hl_max_sequence_backward(real* outputGrad,
                              int *index,
                              real* inputGrad,
                              int numSequences,
                              int dim) {
  CHECK_NOTNULL(outputGrad);
  CHECK_NOTNULL(index);
  CHECK_NOTNULL(inputGrad);

  const int total = numSequences * dim;
  if (total <= 0) {
    return;
  }

  const unsigned int threadsPerBlock = 128;
  unsigned int blocks = (total + threadsPerBlock - 1) / threadsPerBlock;
  dim3 threads(threadsPerBlock, 1);
  dim3 grid(blocks, 1);
  KeMaxSequenceBackward<<< grid, threads, 0, STREAM_DEFAULT >>>
      (outputGrad, index, inputGrad, numSequences, dim);
  CHECK_SYNC("hl_max_sequence_backward failed");
}

/*
 * Context projection forward.
 *
 * For every row j of a sequence, the output row (contextLength * inputDim
 * wide) receives, in slot c (0 <= c < contextLength), source row
 * j + c + contextStart. Values are added (+=) into the output.
 *
 * Launch layout: one block per sequence (blockIdx.x). The threads of the
 * block stride over the inputDim columns, so each thread owns a disjoint set
 * of columns.
 *
 * Source rows outside [0, seqLen) are handled as follows:
 *  - When `padding` is true they are taken from weightData:
 *      - row i for leading positions (i + contextStart < 0);
 *      - row beginPad + (i + contextStart - seqLen) for trailing positions.
 *  - Otherwise they contribute nothing.
 *
 * NOTE(review): output is accumulated with +=, so the caller presumably
 * zero-initializes it. Confirm.
 */
template <bool padding>
__global__ void KeContextProjectionForward(real* input,
                                           const int* sequence,
                                           real* weightData,
                                           real* output,
                                           int inputDim,
                                           int contextLength,
                                           int contextStart,
                                           int beginPad) {
  int idx = threadIdx.x;
  int blockSize = blockDim.x;
  int sequenceId = blockIdx.x;
  int seqStart = sequence[sequenceId];
  int seqEnd = sequence[sequenceId+1];
  real value = 0;

  // Instance i (0 <= i < seqLen + contextLength - 1) stands for source row
  // i + contextStart; together they cover every row any output slot reads.
  int instances = seqEnd - seqStart + contextLength - 1;
  // Rebase both tensors to the first row of this sequence.
  output += seqStart * inputDim * contextLength;
  input += seqStart * inputDim;
  // Column passes: this thread handles columns idx, idx + blockSize, ...
  for (int k = 0; k <= inputDim / blockSize; k++) {
    if (idx < inputDim) {
      for (int i = 0; i < instances; i++) {
        // Fetch the value of source row i + contextStart (or its padding).
        if ((i + contextStart) < 0) {
          if (padding) {
            value = weightData[i * inputDim + idx];
          } else {
            continue;
          }
        } else if ((i + contextStart) >= (seqEnd - seqStart)) {
          if (padding) {
            value =
              weightData[(beginPad + i + contextStart - (seqEnd - seqStart)) *
                         inputDim + idx];
          } else {
            continue;
          }
        } else {
          value = input[(i + contextStart) * inputDim + idx];
        }

        // First output (row, slot) reading instance i:
        // slot outx = min(i, contextLength - 1), row outy = i - outx.
        int outx = (i - contextLength) < 0 ? i : (contextLength - 1);
        int outy = (i - contextLength) < 0 ? 0 : (i - (contextLength - 1));
        real* output_r =
          output + outy * inputDim * contextLength + outx * inputDim;
        // Walk the pairs with row + slot == i. Each step moves one row
        // forward (+contextLength * inputDim) and one slot back (-inputDim).
        // Stop after slot 0 or at the end of the sequence.
        for (int j = outy; j < seqEnd - seqStart; j++) {
          output_r[idx] += value;
          if (j - outy == outx) break;
          output_r += (contextLength - 1) * inputDim;
        }
      }
    }
    idx += blockSize;
  }
}

/*
 * Host launcher for KeContextProjectionForward.
 *
 * Launches one 128-thread block per sequence on STREAM_DEFAULT. The
 * kernel's `padding` template argument is set from isPadding, and weightData
 * is only required when isPadding is true.
 *
 * Nothing is launched for numSequences <= 0, since a zero-block grid is an
 * invalid launch configuration.
 */
void hl_context_projection_forward(real* input,
                                   const int* sequence,
                                   real* weightData,
                                   real* output,
                                   int numSequences,
                                   int inputDim,
                                   int contextLength,
                                   int contextStart,
                                   int beginPad,
                                   bool isPadding) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(output);
  CHECK(!isPadding || weightData);

  if (numSequences <= 0) {
    return;
  }

  dim3 threads(128, 1);
  dim3 grid(numSequences, 1);

  if (isPadding) {
    KeContextProjectionForward<true><<< grid, threads, 0, STREAM_DEFAULT >>>
      (input, sequence, weightData, output, inputDim,
       contextLength, contextStart, beginPad);
  } else {
    KeContextProjectionForward<false><<< grid, threads, 0, STREAM_DEFAULT >>>
      (input, sequence, weightData, output, inputDim,
       contextLength, contextStart, beginPad);
  }
  CHECK_SYNC("hl_context_projection_forward failed");
}

/*
 * Context projection backward with respect to the input.
 *
 * For every in-range source row r = i + contextStart of a sequence, this
 * adds into inputGrad row r the sum of all outputGrad slots (row j, slot c)
 * with j + c == i, i.e. every slot that read row r in the forward pass.
 *
 * Launch layout: one block per sequence (blockIdx.x); threads stride over
 * the inputDim columns. Each thread read-modify-writes only its own
 * columns, and instances are processed one after another, so no atomics
 * are used.
 *
 * Positions whose source row lies outside the sequence (padding) are
 * skipped here.
 */
__global__ void KeContextProjectionBackwardData(real* outputGrad,
                                                const int* sequence,
                                                real* inputGrad,
                                                int inputDim,
                                                int contextLength,
                                                int contextStart) {
  int idx = threadIdx.x;
  int blockSize = blockDim.x;
  int sequenceId = blockIdx.x;
  int seqStart = sequence[sequenceId];
  int seqEnd = sequence[sequenceId+1];
  real value = 0;

  // Instance i stands for source row i + contextStart (see forward pass).
  int instances = seqEnd - seqStart + contextLength - 1;
  // Rebase both tensors to the first row of this sequence.
  outputGrad += seqStart * inputDim * contextLength;
  inputGrad += seqStart * inputDim;
  // Column passes: this thread handles columns idx, idx + blockSize, ...
  for (int k = 0; k <= inputDim / blockSize; k++) {
    if (idx < inputDim) {
      for (int i = 0; i < instances; i++) {
        if ((i + contextStart) < 0) {
          continue;
        } else if ((i + contextStart) >= (seqEnd - seqStart)) {
          continue;
        } else {
          // Start from the existing gradient so contributions accumulate.
          value = inputGrad[(i + contextStart) * inputDim + idx];
        }

        // First output (row, slot) that read instance i:
        // slot outx = min(i, contextLength - 1), row outy = i - outx.
        int outx = (i - contextLength) < 0 ? i : (contextLength - 1);
        int outy = (i - contextLength) < 0 ? 0 : (i - (contextLength - 1));
        real* output_r =
          outputGrad + outy * inputDim * contextLength + outx * inputDim;
        // Sum along row + slot == i: one row forward, one slot back per step.
        for (int j = outy; j < seqEnd - seqStart; j++) {
          value += output_r[idx];
          if (j - outy == outx) break;
          output_r += (contextLength - 1) * inputDim;
        }
        inputGrad[(i + contextStart) * inputDim + idx] = value;
      }
    }
    idx += blockSize;
  }
}

/*
 * Host launcher for KeContextProjectionBackwardData: one 128-thread block
 * per sequence on STREAM_DEFAULT. inputGrad is accumulated into, not
 * overwritten.
 *
 * Nothing is launched for numSequences <= 0, since a zero-block grid is an
 * invalid launch configuration.
 */
void hl_context_projection_backward_data(real* outputGrad,
                                         const int* sequence,
                                         real* inputGrad,
                                         int numSequences,
                                         int inputDim,
                                         int contextLength,
                                         int contextStart) {
  CHECK_NOTNULL(outputGrad);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(inputGrad);

  if (numSequences <= 0) {
    return;
  }

  dim3 threads(128, 1);
  dim3 grid(numSequences, 1);
  KeContextProjectionBackwardData<<< grid, threads, 0, STREAM_DEFAULT >>>
    (outputGrad, sequence, inputGrad, inputDim, contextLength, contextStart);
  CHECK_SYNC("hl_context_projection_backward_data failed");
}

/*
 * Context projection backward with respect to the padding weights.
 *
 * Launch layout (1-D grid, blocks of THREADS_X x THREADS_Y threads):
 *  - blockIdx.x alone selects padding row padId = blockIdx.x / padOfBlock
 *    and a THREADS_X-wide column chunk (blockIdx.x % padOfBlock).
 *  - threadIdx.x selects the column within the chunk.
 *  - threadIdx.y strides over the sequences.
 *
 * Each thread sums, over its sequences, the outputGrad slots that read
 * padding row padId in the forward pass. The THREADS_Y partial sums are
 * then tree-reduced in shared memory, and thread row 0 adds the total into
 * weightGrad[padId * weightDim + weightIdx].
 *
 * The reduction halves THREADS_Y at each step, so THREADS_Y must be a power
 * of two for every partial sum to be included.
 */
template<int THREADS_X, int THREADS_Y>
__global__ void KeContextProjectionBackwardWeight(real* outputGrad,
                                                  const int* sequence,
                                                  real* weightGrad,
                                                  int numSequences,
                                                  int weightDim,
                                                  int contextLength,
                                                  int contextStart,
                                                  int beginPad) {
  __shared__ real sum_s[THREADS_Y][THREADS_X];
  // Number of THREADS_X-wide column chunks per padding row.
  int padOfBlock = (weightDim + THREADS_X - 1) / THREADS_X;
  const int idx = threadIdx.x;
  const int idy = threadIdx.y;
  int padId = blockIdx.x / padOfBlock;
  int weightIdx = idx + THREADS_X * (blockIdx.x % padOfBlock);
  int instanceId;
  real value = 0;
  real* output_r;

  // Zero every slot so out-of-range columns contribute 0 to the reduction.
  sum_s[idy][idx] = 0.0f;
  if (weightIdx < weightDim) {
    for (int seqId = idy; seqId < numSequences; seqId += THREADS_Y) {
      int seqStart = sequence[seqId];
      int seqEnd = sequence[seqId+1];
      output_r = outputGrad + seqStart * weightDim * contextLength;

      // Map padding row padId back to the forward "instance" index i
      // (source row i + contextStart) that read it.
      if (contextStart < 0) {
        if (padId + contextStart < 0) {
          // Leading pad: forward used weight row i for i + contextStart < 0.
          instanceId = padId;
        } else {
          // Trailing pad (leading pads exist, beginPad > 0): invert
          // forward's row = beginPad + i + contextStart - seqLen.
          // NOTE(review): presumably beginPad == -contextStart here; confirm.
          instanceId = (padId - beginPad) + (seqEnd - seqStart) - contextStart;
        }
      } else {
        if (padId + (seqEnd - seqStart) < contextStart) {
          // The instance index would be negative: row unused by this sequence.
          continue;
        } else {
          // No leading pads (beginPad == 0): every pad row is trailing.
          instanceId = padId + (seqEnd - seqStart) - contextStart;
        }
      }

      // First output (row, slot) that read this instance:
      // slot outx = min(instanceId, contextLength - 1), row outy = id - outx.
      int outx = (instanceId - contextLength) < 0 ?
                 instanceId : (contextLength - 1);
      int outy = (instanceId - contextLength) < 0 ?
                 0 : (instanceId - (contextLength - 1));
      output_r += outy * weightDim * contextLength + outx * weightDim;
      // Sum along row + slot == instanceId: one row forward, one slot back.
      for (int j = outy; j < seqEnd - seqStart; j++) {
        value += output_r[weightIdx];
        if (j - outy == outx) break;
        output_r += (contextLength - 1) * weightDim;
      }
    }
    sum_s[idy][idx] = value;
  }
  // Barrier reached by all threads (outside the divergent branch above).
  __syncthreads();

  // Tree reduction over threadIdx.y; the result ends up in sum_s[0][idx].
  for (int stride = THREADS_Y/2; stride > 0; stride = stride/2) {
    if (idy < stride) {
      sum_s[idy][idx] += sum_s[idy + stride][idx];
    }
    __syncthreads();
  }
  // Extra barrier; the reduction loop already ends with one.
  __syncthreads();

  if (weightIdx < weightDim) {
    if (idy == 0) {
      weightGrad[padId * weightDim + weightIdx] += sum_s[0][idx];
    }
  }
}

/*
 * Host launcher for KeContextProjectionBackwardWeight.
 *
 * Launches totalPad * ceil(weightDim / 32) blocks of 32 x 32 threads on
 * STREAM_DEFAULT. weightGrad is accumulated into, not overwritten.
 *
 * Nothing is launched when there are no padding rows or no weight columns,
 * since a zero-block grid is an invalid launch configuration.
 */
void hl_context_projection_backward_weight(real* outputGrad,
                                           const int* sequence,
                                           real* weightGrad,
                                           int numSequences,
                                           int weightDim,
                                           int totalPad,
                                           int contextLength,
                                           int contextStart,
                                           int beginPad) {
  CHECK_NOTNULL(outputGrad);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(weightGrad);

  // Used both as the launch shape and as the kernel's template arguments,
  // which must agree. THREADS_Y must stay a power of two for the reduction.
  const int threadsX = 32;
  const int threadsY = 32;
  const int blocksX = totalPad * ((weightDim + threadsX - 1) / threadsX);
  if (blocksX <= 0) {
    return;
  }
  dim3 threads(threadsX, threadsY);
  dim3 grid(blocksX, 1);

  KeContextProjectionBackwardWeight<threadsX, threadsY>
    <<< grid, threads, 0, STREAM_DEFAULT >>>
    (outputGrad, sequence, weightGrad, numSequences, weightDim,
     contextLength, contextStart, beginPad);
  CHECK_SYNC("hl_context_projection_backward_weight failed");
}

/*
 * Row gather-add / scatter-add between `output` (numSamples rows) and
 * `table` (tableSize rows). Both are row-major with `dim` columns, and the
 * rows are paired through ids[sampleId].
 *
 *   AddRow == false: output[sampleId] += table[ids[sampleId]]
 *   AddRow == true : table[ids[sampleId]] += output[sampleId]
 *
 * The AddRow == true case uses paddle::paddleAtomicAdd, since several
 * samples may name the same table row.
 *
 * Samples whose id is outside [0, tableSize) are skipped.
 *
 * Launch layout: gridDimX blocks of blockDimX x blockDimY threads.
 *  - Thread row threadIdx.y of block blockIdx.x starts at sample
 *    blockIdx.x + threadIdx.y * gridDimX and advances by
 *    blockDimY * gridDimX.
 *  - threadIdx.x strides over the columns by blockDimX.
 * The template parameters are used instead of the built-in blockDim/gridDim,
 * so they must equal the actual launch shape for every sample to be visited
 * exactly once.
 */
template<int blockDimX, int blockDimY, int gridDimX, bool AddRow>
__global__ void KeMatrixAddRows(real* output,
                                real* table,
                                int* ids,
                                int numSamples,
                                int tableSize,
                                int dim) {
  int idx = threadIdx.x;
  int idy = threadIdx.y;
  int sampleId = blockIdx.x + idy * gridDimX;

  while (sampleId < numSamples) {
    int tableId = ids[sampleId];
    if ((0 <= tableId) && (tableId < tableSize)) {
      real *outputData = output + sampleId * dim;
      real *tableData = table + tableId * dim;
      for (int i = idx; i < dim; i += blockDimX) {
        if (!AddRow) {
          outputData[i] += tableData[i];
        } else {
          paddle::paddleAtomicAdd(&tableData[i], outputData[i]);
        }
      }
    }
    sampleId += blockDimY * gridDimX;
  }
}

/*
 * Copies (or adds) whole rows between `batch` and `sequence`. Both are
 * row-major with row length seqWidth. Batch row id pairs with sequence row
 * batchIndex[id], for id in [0, batchCount).
 *
 *   seq2batch  isAdd   effect
 *   true       false   batch[id]  = sequence[batchIndex[id]]
 *   true       true    batch[id] += sequence[batchIndex[id]]
 *   false      false   sequence[batchIndex[id]]  = batch[id]
 *   false      true    sequence[batchIndex[id]] += batch[id]
 *
 * Launch layout: gridDimX blocks of blockDimX x blockDimY threads.
 *  - Thread row threadIdx.y of block blockIdx.x starts at batch row
 *    blockIdx.x + threadIdx.y * gridDimX and advances by
 *    blockDimY * gridDimX.
 *  - threadIdx.x strides over the columns by blockDimX.
 * The template parameters replace the built-in blockDim/gridDim, so they
 * must equal the actual launch shape.
 *
 * NOTE(review): in the batch -> sequence direction the writes are plain
 * (non-atomic). This presumably relies on batchIndex having no duplicate
 * entries. Confirm.
 */
template<int blockDimX, int blockDimY, int gridDimX, bool seq2batch, bool isAdd>
__global__
void KeSequence2Batch(real *batch,
                      real *sequence,
                      const int *batchIndex,
                      int seqWidth,
                      int batchCount) {
  int idx = threadIdx.x;
  int idy = threadIdx.y;
  int id = blockIdx.x + idy * gridDimX;
  while (id < batchCount) {
    int seqId = batchIndex[id];
    real* batchData = batch + id * seqWidth;
    real* seqData = sequence + seqId * seqWidth;
    for (int i = idx; i < seqWidth; i += blockDimX) {
      if (seq2batch) {
        if (isAdd) {
          batchData[i] += seqData[i];
        } else {
          batchData[i] = seqData[i];
        }
      } else {
        if (isAdd) {
          seqData[i] += batchData[i];
        } else {
          seqData[i] = batchData[i];
        }
      }
    }
    id += blockDimY * gridDimX;
  }
}

/*
 * Copies rows between `batch` and `sequence` (row length seqWidth). Batch
 * row id pairs with sequence row batchIndex[id], for id in [0, batchCount).
 *
 * seq2batch selects the direction:
 *  - true:  sequence -> batch
 *  - false: batch -> sequence
 *
 * Uses a fixed launch of 8 blocks of 128 x 8 threads on STREAM_DEFAULT. The
 * kernel loops over rows, so any batchCount is covered.
 */
void hl_sequence2batch_copy(real *batch,
                            real *sequence,
                            const int *batchIndex,
                            int seqWidth,
                            int batchCount,
                            bool seq2batch) {
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(batchIndex);

  // Must match the <blockDimX, blockDimY, gridDimX> template arguments
  // below: the kernel strides with those values, not the built-ins.
  dim3 threads(128, 8);
  dim3 grid(8, 1);
  if (seq2batch) {
    KeSequence2Batch<128, 8, 8, true, false>
      <<< grid, threads, 0, STREAM_DEFAULT >>>
      (batch, sequence, batchIndex, seqWidth, batchCount);
  } else {
    KeSequence2Batch<128, 8, 8, false, false>
      <<< grid, threads, 0, STREAM_DEFAULT >>>
      (batch, sequence, batchIndex, seqWidth, batchCount);
  }
  CHECK_SYNC("hl_sequence2batch_copy failed");
}

/*
 * Accumulating row transfer between `batch` and `sequence` (row length
 * seqWidth). Batch row id pairs with sequence row batchIndex[id], and rows
 * are added rather than assigned.
 *
 * seq2batch == true adds sequence rows into the batch; false adds batch rows
 * into the sequence.
 *
 * Fixed launch: 8 blocks of 128 x 8 threads on STREAM_DEFAULT, matching the
 * kernel's template arguments.
 */
void hl_sequence2batch_add(real *batch,
                           real *sequence,
                           int *batchIndex,
                           int seqWidth,
                           int batchCount,
                           bool seq2batch) {
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(batchIndex);

  const dim3 blockShape(128, 8);
  const dim3 gridShape(8, 1);
  if (!seq2batch) {
    // batch -> sequence, accumulate
    KeSequence2Batch<128, 8, 8, false, true>
        <<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
        (batch, sequence, batchIndex, seqWidth, batchCount);
  } else {
    // sequence -> batch, accumulate
    KeSequence2Batch<128, 8, 8, true, true>
        <<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
        (batch, sequence, batchIndex, seqWidth, batchCount);
  }
  CHECK_SYNC("hl_sequence2batch_add failed");
}

/*
 * Copies one time step of one sequence between the packed `sequence` tensor
 * and the padded, time-major `batch` tensor.
 *
 * Layouts (row length sequenceWidth):
 *   sequence row: sequenceStartPositions[s] + t
 *   batch row:    t * numSequences + s
 *
 * Launch layout: blockIdx.y is the sequence s; the time step is
 * t = blockIdx.x * blockDim.y + threadIdx.y; threadIdx.x strides over the
 * columns of the row.
 *
 * Steps t < length(s) are copied, scaled by 1 / length(s) when normByTimes.
 * Steps length(s) <= t < maxSequenceLength are padding: they are zeroed in
 * the batch when seq2batch and left untouched otherwise.
 */
template<bool normByTimes, bool seq2batch>
__global__
void KeSequence2BatchPadding(real* batch,
                             real* sequence,
                             const int* sequenceStartPositions,
                             const size_t sequenceWidth,
                             const size_t maxSequenceLength,
                             const size_t numSequences) {
  const int seqNo = blockIdx.y;
  const int firstRow = sequenceStartPositions[seqNo];
  const int seqLen = sequenceStartPositions[seqNo + 1] - firstRow;

  const int step = blockIdx.x * blockDim.y + threadIdx.y;
  const int batchOffset = (step * numSequences + seqNo) * sequenceWidth;
  const int seqOffset = (firstRow + step) * sequenceWidth;

  if (step >= seqLen) {
    // Padding step (or past the padded length): only seq2batch writes zeros.
    if (seq2batch && step < maxSequenceLength) {
      for (int col = threadIdx.x; col < sequenceWidth; col += blockDim.x) {
        batch[batchOffset + col] = 0;
      }
    }
    return;
  }

  real* to = seq2batch ? batch + batchOffset : sequence + seqOffset;
  const real* from = seq2batch ? sequence + seqOffset : batch + batchOffset;

  if (normByTimes) {
    const real scale = 1.0f / (real)seqLen;
    for (int col = threadIdx.x; col < sequenceWidth; col += blockDim.x) {
      to[col] = scale * from[col];
    }
  } else {
    for (int col = threadIdx.x; col < sequenceWidth; col += blockDim.x) {
      to[col] = from[col];
    }
  }
}

/*
 * Converts between a packed sequence tensor and a padded, time-major batch
 * tensor.
 *  - Packed tensor: rows of all sequences back to back, with boundaries in
 *    sequenceStartPositions.
 *  - Batch tensor: laid out as [maxSequenceLength][numSequences][sequenceWidth].
 *
 * seq2batch == true copies sequence -> batch (padding steps are zeroed);
 * false copies batch -> sequence. With normByTimes, each copied row is
 * scaled by 1 / (length of its sequence).
 *
 * Fast path: a single sequence without normalization is done as one
 * device-to-device memcpy of maxSequenceLength * sequenceWidth elements.
 * NOTE(review): that path assumes the sequence starts at row 0 and has
 * exactly maxSequenceLength rows. Confirm with callers.
 *
 * Empty inputs (zero width, length or sequence count) are a no-op.
 */
void hl_sequence2batch_copy_padding(real* batch,
                                    real* sequence,
                                    const int* sequenceStartPositions,
                                    const size_t sequenceWidth,
                                    const size_t maxSequenceLength,
                                    const size_t numSequences,
                                    bool normByTimes,
                                    bool seq2batch) {
  CHECK_NOTNULL(batch);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(sequenceStartPositions);

  // Nothing to copy. Returning here also avoids dividing by a zero
  // blockDimX below and launching a grid with a zero dimension.
  if (sequenceWidth == 0 || maxSequenceLength == 0 || numSequences == 0) {
    return;
  }

  if (!normByTimes && numSequences == 1) {
    size_t elementCount = maxSequenceLength * sequenceWidth;
    if (seq2batch) {
      /* sequence -> batch */
      hl_memcpy_device2device(batch, sequence, sizeof(real) * elementCount);
    } else {
      /* batch -> sequence */
      hl_memcpy_device2device(sequence, batch, sizeof(real) * elementCount);
    }
    return;
  }

  const int CUDA_BLOCK_SIZE = 512;

  /* Use a multiple of 32 threads along x, enough that each thread copies at
     most 8 elements of a row (capped at CUDA_BLOCK_SIZE threads). */
  int blockDimX = ((((sequenceWidth + 7) >> 3) + 31) >> 5) << 5;
  blockDimX = (blockDimX < CUDA_BLOCK_SIZE) ? blockDimX : CUDA_BLOCK_SIZE;

  int blockDimY = CUDA_BLOCK_SIZE / blockDimX;
  dim3 threads(blockDimX, blockDimY);

  // Each block covers blockDimY time steps. Ceil-divide so that every step
  // in [0, maxSequenceLength) gets a thread row. blockDimX need not divide
  // CUDA_BLOCK_SIZE (e.g. 96), so sizing by total threads would miss steps.
  int gridDimX = (maxSequenceLength + blockDimY - 1) / blockDimY;
  int gridDimY = numSequences;
  dim3 grid(gridDimX, gridDimY);

  if (seq2batch) {
    /* sequence -> batch */
    if (normByTimes) {
      KeSequence2BatchPadding<1, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    } else {
      KeSequence2BatchPadding<0, 1><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    }
  } else {
    /* batch -> sequence */
    if (normByTimes) {
      KeSequence2BatchPadding<1, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    } else {
      KeSequence2BatchPadding<0, 0><<< grid, threads, 0, STREAM_DEFAULT >>>(
              batch, sequence, sequenceStartPositions,
              sequenceWidth, maxSequenceLength, numSequences);
    }
  }

  CHECK_SYNC("hl_sequence2batch_copy_padding failed");
}

/* Single-precision 1 / sqrt(x) via the CUDA math function rsqrtf. It is
   overloaded with a double version so callers can pass `real`. */
__device__ inline float my_rsqrt(float x) {
  const float inv = rsqrtf(x);
  return inv;
}

/* Double-precision 1 / sqrt(x) via the CUDA math function rsqrt. */
__device__ inline double my_rsqrt(double x) {
  const double inv = rsqrt(x);
  return inv;
}

/*
 * Sequence pooling. dst[s][col] is computed from src rows
 * [starts[s], starts[s + 1]) in column col. Both tensors are row-major with
 * `width` columns, and s ranges over `height` sequences.
 *
 *   mode 0:          average  (sum / length)
 *   mode 1:          sum
 *   any other mode:  sum / sqrt(length)
 *
 * Launch layout: one thread per output element (height * width, 1-D grid).
 * Empty sequences leave their dst row untouched.
 */
__global__ void KeSequenceAvgForward(real* dst,
                                     real* src,
                                     const int* starts,
                                     int height,
                                     int width,
                                     const int mode) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid >= height * width) {
    return;
  }

  const int seq = gid / width;
  const int col = gid % width;
  const int first = starts[seq];
  const int last = starts[seq + 1];
  const int len = last - first;
  if (len == 0) {
    return;
  }

  real acc = 0.0;
  for (int r = first; r < last; ++r) {
    acc += src[r * width + col];
  }

  if (mode == 0) {
    acc = acc / len;
  } else if (mode != 1) {
    acc = acc * my_rsqrt((real)len);
  }
  dst[gid] = acc;
}

/*
 * Host launcher for KeSequenceAvgForward.
 *
 * Launches one thread per output element (height * width) in blocks of 512
 * on STREAM_DEFAULT. mode must be 0 (average), 1 (sum) or 2 (sum / sqrt(length)).
 *
 * Nothing is launched when there are no output elements, since a zero-block
 * grid is an invalid launch configuration.
 */
void hl_sequence_avg_forward(real* dst,
                             real* src,
                             const int* starts,
                             int height,
                             int width,
                             const int mode) {
  CHECK_NOTNULL(dst);
  CHECK_NOTNULL(src);
  CHECK_NOTNULL(starts);
  CHECK(mode == 0 || mode == 1 || mode == 2)
    << "mode error in hl_sequence_avg_forward!";

  const int total = width * height;
  if (total <= 0) {
    return;
  }

  const int block = 512;
  const int grid = DIVUP(total, block);

  KeSequenceAvgForward<<< grid, block, 0, STREAM_DEFAULT >>>
           (dst, src, starts, height, width, mode);
  CHECK_SYNC("hl_sequence_avg_forward failed");
}