/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include "hl_base.h"
#include "hl_cuda_cublas.h"
#include "hl_device_functions.cuh"
#include "hl_activation_functions.h"
#include "paddle/utils/Logging.h"

// Callable types for element-wise activations on `real` values.
// In this file a t_forward is invoked with one argument (the value) and a
// t_backward with two arguments (a gradient and a value); the concrete
// functions are looked up in the hppl::gpu::forward / hppl::gpu::backward
// tables, indexed by an hl_activation_mode_t.
// NOTE(review): hppl::Active is presumably declared in
// hl_activation_functions.h - confirm.
typedef hppl::Active<real>::forward  t_forward;
typedef hppl::Active<real>::backward t_backward;

/**
 * Report whether the sequence-parallel LSTM path applies to a frame size.
 *
 * @param frameSize  width of one LSTM frame.
 * @return true only for frame sizes 32 and 64, false for anything else.
 */
bool hl_lstm_sequence_parallel(int frameSize) {
  return frameSize == 32 || frameSize == 64;
}

/**
 * Cursor over a buffer of `real` values organised as consecutive frames.
 *
 * The cursor wraps a raw pointer; init() positions it on one element of the
 * first frame to visit and nextFrame() moves it by one frame.  The template
 * argument `reversed` selects the walking direction (0: towards higher
 * addresses, non-zero: towards lower addresses) and `frameSize` is the number
 * of elements between two consecutive frames.
 */
class frameValue {
public:
  real *value_;  // element the cursor currently points at

  __device__ frameValue(real *value) : value_(value) {}

  /**
   * Move the cursor to element `idx` of the first frame to visit:
   * frame `start` when walking forwards, frame `start + length - 1` when
   * walking in reverse.
   */
  template <int reversed, int frameSize>
  __device__ inline void init(int start, int length, int idx) {
    const int firstFrame = (reversed == 0) ? start : (start + length - 1);
    value_ += firstFrame * frameSize + idx;
  }

  /** Raw pointer to the current element. */
  __device__ inline real *getPtr() const {return value_;}

  /** Read the current element. */
  __device__ inline real getValue() {return *value_;}

  /** Overwrite the current element. */
  __device__ inline void setValue(real value) {*value_ = value;}

  /** Step to the same element of the next frame in walking order. */
  template <int reversed, int frameSize>
  __device__ inline void nextFrame() {
    value_ += (reversed == 0) ? frameSize : -frameSize;
  }
};

/**
 * Wait on a named barrier by issuing the PTX instruction
 * `bar.sync id, barriers`.
 *
 * @param id        first operand: the barrier to use.
 * @param barriers  second operand.  Despite the parameter name, every call
 *                  site in this file passes a thread count (for example
 *                  valueSize or frameSize * 2), which matches the PTX form
 *                  `bar.sync a, b` where b is the number of participating
 *                  threads.
 *
 * The "memory" clobber stops the compiler from moving memory accesses across
 * the instruction.
 */
__device__ __forceinline__
void ptx_sync(const int id, const int barriers) {
  asm volatile("bar.sync %0, %1;" : : "r"(id), "r"(barriers) : "memory");
}

/**
 * Signal arrival at a named barrier by issuing the PTX instruction
 * `bar.arrive id, barriers`.  In this file it is paired with ptx_sync() calls
 * on the same barrier id and count made by the threads that consume the data.
 *
 * @param id        first operand: the barrier to use.
 * @param barriers  second operand.  Despite the parameter name, call sites in
 *                  this file pass the same thread count as the matching
 *                  ptx_sync() call.
 *
 * The "memory" clobber stops the compiler from moving memory accesses across
 * the instruction.
 */
__device__ __forceinline__
void ptx_arrive(const int id, const int barriers) {
  asm volatile("bar.arrive %0, %1;" : : "r"(id), "r"(barriers) : "memory");
}

/**
 * One forward LSTM step, executed cooperatively by the valueSize threads
 * whose `index` lies in [0, valueSize).
 *
 * With idx = index % frameSize and idy = index / frameSize, the role of a
 * thread is selected by idy:
 *   idy == 0    : applies activeNode to `value`, later produces preOutput;
 *   idy == 1, 2 : add state[idx] * check to `value`, then apply activeGate;
 *   idy == 3    : updates state[idx], then adds the new state * check to
 *                 `value`, applies activeGate and writes output[idx].
 * (KeLstmForward passes checkIg / checkFg / checkOg as `check` for
 * idy 1 / 2 / 3.)
 *
 * @param value       this thread's pre-activation gate input.
 * @param shValue     valueSize-element scratch array holding the activated
 *                    values of all threads.
 * @param state       frameSize-element state array, read and updated here.
 * @param preOutput   frameSize-element array receiving activeState(state).
 * @param output      frameSize-element array receiving preOutput * value of
 *                    the idy == 3 thread.
 * @param check       per-thread multiplier applied to the state before the
 *                    gate activation (unused when idy == 0).
 * @param index       thread position in [0, valueSize).
 * @param activeNode  activation applied by idy == 0 threads to `value`.
 * @param activeGate  activation applied by idy == 1, 2, 3 threads.
 * @param activeState activation applied to the updated state.
 * @return the thread's activated value.
 *
 * Named barriers 1, 2 and 3 are used; the statement order around them is
 * what orders the producers and consumers of shValue, state and preOutput.
 */
template<int valueSize, int frameSize>
__device__ __forceinline__ real
forward_sequence(real value,
                 real *shValue,
                 real *state,
                 real *preOutput,
                 real *output,
                 real check,
                 int index,
                 t_forward activeNode,
                 t_forward activeGate,
                 t_forward activeState) {
  real out;
  real prevOut;
  real state_r;
  const int idx = index % frameSize;
  const int idy = index / frameSize;
  // assert(index < valueSize);

  if (idy == 0) {
    value = activeNode(value);
    shValue[index] = value;
  }
  if (idy == 1 || idy == 2) {
    // These threads read the state of the previous step; it is only
    // overwritten by the idy == 3 threads after barrier 1 below.
    state_r = state[idx];
    value += state_r * check;
    value = activeGate(value);
    shValue[index] = value;
  }
  // All valueSize threads meet here so that slots 0..2 of shValue are
  // complete before the idy == 3 threads consume them.
  ptx_sync(1, valueSize);
  if (idy == 3) {
    // new state = old state * shValue[slot 2] + shValue[slot 0] * shValue[slot 1]
    state_r = state[idx];
    state_r = state_r * shValue[idx + frameSize * 2];
    state_r += shValue[idx] * shValue[idx + frameSize];
    state[idx] = state_r;
    // Tell the idy == 0 threads (waiting on barrier 2) that the state is
    // updated.
    ptx_arrive(2, frameSize * 2);
    value += state_r * check;
    value = activeGate(value);
    shValue[index] = value;
    // Wait until the idy == 0 threads have written preOutput.
    ptx_sync(3, frameSize * 2);
    prevOut = preOutput[idx];
    out = prevOut * value;
    output[idx] = out;
  }
  if (idy == 0) {
    // Wait for the updated state written by the idy == 3 threads.
    ptx_sync(2, frameSize * 2);
    prevOut = state[idx];
     prevOut = activeState(prevOut);
    preOutput[idx] = prevOut;
    // Release the idy == 3 threads waiting on barrier 3.
    ptx_arrive(3, frameSize * 2);
  }
  return value;
}

// Named-barrier ids used by KeLstmForward through ptx_sync / ptx_arrive.
// OUTPUT_BARRIER_ID  : arrived at / synced on after forward_sequence, before
//                      shOutput is read for the recurrent product.
// OUTPUT_BARRIER_ID2 : valueSize == 128 path only; arrived at after the
//                      recurrent product has been written to shValue and
//                      synced on before that product is added to `value`.
#define     OUTPUT_BARRIER_ID               10
#define     OUTPUT_BARRIER_ID2              11
/**
 * Forward LSTM over whole sequences; block blockIdx.x processes the frames
 * [starts[blockIdx.x], starts[blockIdx.x + 1]).
 *
 * Launch layout (see hl_lstm_parallel_forward): a 1-D block of `blockSize`
 * threads.  Threads with index < valueSize each own one element of the
 * valueSize-wide gate vector; idy = index / frameSize selects the role
 * (0: node input, 1/2/3: gates using checkIg / checkFg / checkOg).
 *   - valueSize == 128: threads valueSize..blockSize-1 are compute threads
 *     that form the product of the current output with `weight` for the
 *     next frame and hand it over through shValue.
 *   - valueSize == 256: every thread computes that product for its own
 *     element and adds it directly to `value`.
 *
 * @param gateValue  gate buffer, valueSize values per frame; read as input
 *                   and overwritten with the activated values.
 * @param state      per-frame state, frameSize values per frame (written).
 * @param output     per-frame output, frameSize values per frame (written).
 * @param preOutput  per-frame activated state, frameSize values per frame
 *                   (written).
 * @param checkIg, checkFg, checkOg  frameSize-element multipliers applied to
 *                   the state by the idy == 1 / 2 / 3 threads.
 * @param weight     recurrent weight, indexed as weight[n * valueSize + col]
 *                   for n in [0, frameSize).
 * @param starts     frame offsets of the sequences; entries blockIdx.x and
 *                   blockIdx.x + 1 are read.
 * @param active_node, active_gate, active_state  indices into
 *                   hppl::gpu::forward.
 *
 * `computeThreads` is not referenced in the kernel body.
 */
template<int valueSize, int frameSize, int reversed,
         int computeThreads, int blockSize>
__global__ void KeLstmForward(real *gateValue,
                              real *state,
                              real *output,
                              real *preOutput,
                              real *checkIg,
                              real *checkFg,
                              real *checkOg,
                              real *weight,
                              const int *starts,
                              hl_activation_mode_t active_node,
                              hl_activation_mode_t active_gate,
                              hl_activation_mode_t active_state) {
  __shared__ real shValue[valueSize];
  __shared__ real shState[frameSize];
  __shared__ real shPrevOutput[frameSize];
  __shared__ real shOutput[frameSize];

  const int index = threadIdx.x;
  int start = starts[blockIdx.x];
  int length = starts[blockIdx.x + 1] - start;

  /* init */
  // Zero-initialised so that threads which never assign them (idy == 0 for
  // `check`, compute threads for `value`) do not pass indeterminate values.
  real check = 0;
  real value = 0;
  // Weight column kept by each thread that computes the recurrent product.
  // It is filled once, during the first time step, and reused afterwards, so
  // it has to live outside the time-step loop.
  real B_r[frameSize];
  frameValue frameGate(gateValue);
  frameValue frameState(state);
  frameValue frameOutput(output);
  frameValue framePreOutput(preOutput);
  if (index < valueSize) {
    const int idx = index % frameSize;
    const int idy = index / frameSize;
    frameGate.init<reversed, valueSize>(start, length, index);
    value = frameGate.getValue();
    if (idy == 0) {
      shState[idx] = 0.0;
    } else if (idy == 1) {
      check = checkIg[idx];
    } else if (idy == 2) {
      check = checkFg[idx];
    } else if (idy == 3) {
      check = checkOg[idx];
    }

    if (idy == 3) {
      frameState.init<reversed, frameSize>(start, length, idx);
      frameOutput.init<reversed, frameSize>(start, length, idx);
      framePreOutput.init<reversed, frameSize>(start, length, idx);
    }

    // Make the zeroed shState visible to all gate threads.
    ptx_sync(1, valueSize);
  }

  for (int i = 0; i < length; ++i) {
    if (index < valueSize) {
      if (valueSize == 128) {
        if (i != 0) {
          // Wait for the compute threads, then add their recurrent product.
          ptx_sync(OUTPUT_BARRIER_ID2, blockSize);
          value += shValue[index];
        }
      }
      value = forward_sequence<valueSize, frameSize>(
        value, shValue, shState, shPrevOutput, shOutput, check, index,
        hppl::gpu::forward[active_node],
        hppl::gpu::forward[active_gate],
        hppl::gpu::forward[active_state]);
      const int idx = index % frameSize;
      const int idy = index / frameSize;
      if (valueSize == 128) {
        if (idy == 3) {
          // shOutput is written by the idy == 3 threads; release the
          // compute threads waiting for it.
          ptx_arrive(OUTPUT_BARRIER_ID, frameSize + 128);
        }
      }
      if (valueSize == 256) {
        // Every thread reads shOutput below, so all of them wait here.
        ptx_sync(OUTPUT_BARRIER_ID, valueSize);
      }
      frameGate.setValue(value);
      if (idy == 3) {
        frameState.setValue(shState[idx]);
        frameOutput.setValue(shOutput[idx]);
        framePreOutput.setValue(shPrevOutput[idx]);
        frameState.nextFrame<reversed, frameSize>();
        frameOutput.nextFrame<reversed, frameSize>();
        framePreOutput.nextFrame<reversed, frameSize>();
      }
      if (i != length - 1) {
        frameGate.nextFrame<reversed, valueSize>();
        value = frameGate.getValue();
      }
    }
    if (i != length - 1) {
      if (valueSize == 128) {
        if (valueSize <= index) {
          const int computeIdx = index - valueSize;
          if (i == 0) {
            #pragma unroll
            for (int n = 0; n < frameSize; n++) {
              B_r[n] = weight[n * valueSize + computeIdx];
            }
          }
          ptx_sync(OUTPUT_BARRIER_ID, frameSize + 128);
          real A_r[frameSize];
          for (int n = 0; n < frameSize; n++) {
            A_r[n] = shOutput[n];
          }
          real sum = 0.0f;
          for (int n = 0; n < frameSize; n++) {
            sum += A_r[n]*B_r[n];
          }
          shValue[computeIdx] = sum;
          ptx_arrive(OUTPUT_BARRIER_ID2, blockSize);
        }
      }
      if (valueSize == 256) {
        if (i == 0) {
          #pragma unroll
          for (int n = 0; n < frameSize; n++) {
            B_r[n] = weight[n * valueSize + index];
          }
        }
        real sum = 0.0f;
        for (int n = 0; n < frameSize; n++) {
          sum += shOutput[n]*B_r[n];
        }
        value += sum;
      }
    }
  }
}

/**
 * Launch the sequence-parallel forward LSTM kernel on STREAM_DEFAULT.
 *
 * One block of 256 threads is launched per sequence.  The kernel
 * instantiation is chosen from frameSize (32 or 64, enforced by CHECK) and
 * the `reversed` flag; the gate vector handled by the kernel is 4 * frameSize
 * wide.  All buffer and activation arguments are forwarded to KeLstmForward
 * unchanged (note the kernel takes state, output, preOutput in that order).
 *
 * @param sequence      frame offsets of the sequences, passed as `starts`.
 * @param numSequences  number of blocks to launch.
 */
void hl_lstm_parallel_forward(real *gateValue,
                              real *stateValue,
                              real *preOutputValue,
                              real *outputValue,
                              real *checkIg,
                              real *checkFg,
                              real *checkOg,
                              real *weight,
                              const int *sequence,
                              int frameSize,
                              int numSequences,
                              bool reversed,
                              hl_activation_mode_t active_node,
                              hl_activation_mode_t active_gate,
                              hl_activation_mode_t active_state) {
  CHECK(frameSize == 32 || frameSize == 64);
  dim3 grid(numSequences, 1);
  switch (frameSize) {
    case 32:
      if (reversed) {
        KeLstmForward<128, 32, 1, 128, 256>
            <<<grid, 256, 0, STREAM_DEFAULT>>>(
                gateValue, stateValue, outputValue, preOutputValue,
                checkIg, checkFg, checkOg, weight, sequence,
                active_node, active_gate, active_state);
      } else {
        KeLstmForward<128, 32, 0, 128, 256>
            <<<grid, 256, 0, STREAM_DEFAULT>>>(
                gateValue, stateValue, outputValue, preOutputValue,
                checkIg, checkFg, checkOg, weight, sequence,
                active_node, active_gate, active_state);
      }
      break;
    case 64:
      if (reversed) {
        KeLstmForward<256, 64, 1, 256, 256>
            <<<grid, 256, 0, STREAM_DEFAULT>>>(
                gateValue, stateValue, outputValue, preOutputValue,
                checkIg, checkFg, checkOg, weight, sequence,
                active_node, active_gate, active_state);
      } else {
        KeLstmForward<256, 64, 0, 256, 256>
            <<<grid, 256, 0, STREAM_DEFAULT>>>(
                gateValue, stateValue, outputValue, preOutputValue,
                checkIg, checkFg, checkOg, weight, sequence,
                active_node, active_gate, active_state);
      }
      break;
    default:
      // No kernel exists for other frame sizes.
      break;
  }
  CHECK_SYNC("hl_lstm_parallel_forward failed");
}

// Warp-shuffle helper local to transpose_32x32.  The mask-less __shfl
// intrinsic is deprecated since CUDA 9 and is not available when compiling
// for compute capability 7.0 and newer, so __shfl_sync with a full-warp mask
// is used whenever the toolkit provides it.
#if defined(CUDART_VERSION) && CUDART_VERSION >= 9000
#define HL_LSTM_SHFL(var, srcLane) \
  __shfl_sync(0xffffffffu, (var), (srcLane), 32)
#else
#define HL_LSTM_SHFL(var, srcLane) __shfl((var), (srcLane), 32)
#endif

/**
 * Exchange the 32-element per-thread arrays of the 32 lanes of a warp
 * through warp shuffles (by its name, a 32x32 transpose of the tile formed
 * by the lanes' arrays).
 *
 * Must be called by all 32 lanes of the warp together: the shuffles use a
 * full-warp mask.
 *
 * @param a    this lane's 32 values; overwritten with the result.
 * @param idx  this lane's position; callers pass index % 32.
 *
 * Three phases:
 *   1. a[k] (k >= 1) is fetched from lane (idx + k) % 32;
 *   2. the lane's array is rotated by one position once for every
 *      tid < idx, i.e. idx times in total;
 *   3. a[k] is fetched from lane (32 - idx - k) mod 32.
 */
__device__ __forceinline__
void transpose_32x32(real a[], const int idx) {
  int addr = idx % 32;
  #pragma unroll
  for (int k = 1; k < 32; k++) {
    // Equivalent to: a[k] = shuffle(a[k], (idx + k) % 32); the source lane
    // is advanced by pulling `addr` from the next lane on every step.
    addr = HL_LSTM_SHFL(addr, (idx + 1) % 32);
    a[k] = HL_LSTM_SHFL(a[k], addr);
  }

  #pragma unroll
  for (int tid = 0; tid < 31; tid++) {
    // Lanes with idx > tid shift their array by one slot (a[0] -> a[1],
    // ..., a[31] -> a[0]); the other lanes keep it unchanged.
    real tmp = (idx > tid) ? a[0] : a[1];
    #pragma unroll
    for (int k = 31; k > 0; k--) {
      a[(k + 1) % 32] = (idx > tid) ? a[k] : a[(k + 1) % 32];
    }
    a[1] = tmp;
  }

  addr = (32 - idx) % 32;
  #pragma unroll
  for (int k = 0; k < 32; k++) {
    a[k] = HL_LSTM_SHFL(a[k], addr);
    // Move the source lane back by one for the next element.
    addr = HL_LSTM_SHFL(addr, (idx + 31) % 32);
  }
}

#undef HL_LSTM_SHFL

/**
 * One backward LSTM step, executed cooperatively by valueSize threads.
 *
 * With frameIdx = index % frameSize and frameIdy = index / frameSize, the
 * role of a thread is selected by frameIdy (KeLstmBackward loads checkIg /
 * checkFg / checkOg into rCheck for frameIdy 1 / 2 / 3).  Each group
 * publishes a partial state gradient in its slot of shStateGrad, meets the
 * other groups on named barrier 3 and then combines the partial values it
 * needs into its own rStateGrad and rGateGrad.
 *
 * @param rGateValue      this thread's gate value for the current frame.
 * @param rOutputGrad     output gradient (used when frameIdy == 3).
 * @param rPreOutputValue pre-output value (used when frameIdy == 3).
 * @param rGateGrad       in: gate gradient left by the previous call (read
 *                        when frameIdy is 1 or 2); out: gate gradient of
 *                        this step.
 * @param rStateGrad      in/out state gradient register of this thread.
 * @param shStateGrad     valueSize-element exchange array for the partial
 *                        state gradients.
 * @param shStateValue    frameSize-element array read by frameIdy == 2.
 * @param shGateValue     exchange array; slot 0 receives the frameIdy == 0
 *                        gate values, slot 1 the frameIdy == 1 gate values.
 * @param rCheck          per-thread multiplier for the gate gradient.
 * @param rGateValuePrev  in/out; frameIdy == 2 keeps its gate value of the
 *                        previous call here.
 * @param index           thread position in [0, valueSize).
 * @param activeNode, activeGate, activeState  backward activation
 *                        functions, called as f(gradient, value).
 */
template<int valueSize, int frameSize>
__device__ void
backward_sequence(real rGateValue,
                  real rOutputGrad,
                  real rPreOutputValue,
                  real &rGateGrad,
                  real &rStateGrad,
                  real *shStateGrad,
                  real *shStateValue,
                  real *shGateValue,
                  real rCheck,
                  real &rGateValuePrev,
                  int index,
                  t_backward activeNode,
                  t_backward activeGate,
                  t_backward activeState) {
  const int frameIdx = index % frameSize;
  const int frameIdy = index / frameSize;
  if (frameIdy == 3) {
    real rPrevOutputGrad;
    rPrevOutputGrad = rOutputGrad * rGateValue;
    rStateGrad = activeState(rPrevOutputGrad, rPreOutputValue);
    rGateGrad = rOutputGrad * rPreOutputValue;
    rGateGrad = activeGate(rGateGrad, rGateValue);
    rStateGrad += rGateGrad * rCheck;
    shStateGrad[index] = rStateGrad;
    // This group only produces; it signals barrier 3 without waiting.
    ptx_arrive(3, valueSize);
  } else if (frameIdy == 1) {
    shGateValue[frameIdx + frameSize] = rGateValue;
    // rGateGrad still holds the value left by the previous call.
    rStateGrad = rGateGrad * rCheck;
    shStateGrad[index] = rStateGrad;
    ptx_sync(3, valueSize);
    rStateGrad += shStateGrad[frameIdx + frameSize *2];
    rStateGrad += shStateGrad[frameIdx + frameSize *3];
    // Scale by the frameIdy == 0 gate value published in slot 0.
    rGateGrad = rStateGrad * shGateValue[frameIdx];
    rGateGrad = activeGate(rGateGrad, rGateValue);
  } else if (frameIdy == 2) {
    // Carry the state gradient of the previous call through the gate value
    // saved from that call.
    rStateGrad = rStateGrad * rGateValuePrev;
    rStateGrad += rGateGrad * rCheck;
    shStateGrad[index] = rStateGrad;
    ptx_sync(3, valueSize);
    rStateGrad += shStateGrad[frameIdx + frameSize];
    rStateGrad += shStateGrad[frameIdx + frameSize *3];
    rGateValuePrev = rGateValue;
    rGateGrad = rStateGrad * shStateValue[frameIdx];
    rGateGrad = activeGate(rGateGrad, rGateValue);
  } else if (frameIdy == 0) {
    shGateValue[frameIdx] = rGateValue;
    ptx_sync(3, valueSize);
    // Total state gradient = sum of the three published partial values.
    rStateGrad = shStateGrad[frameIdx + frameSize];
    rStateGrad += shStateGrad[frameIdx + frameSize *2];
    rStateGrad += shStateGrad[frameIdx + frameSize *3];
    // Scale by the frameIdy == 1 gate value published in slot 1.
    rGateGrad = rStateGrad * shGateValue[frameIdx + frameSize];
    rGateGrad = activeNode(rGateGrad, rGateValue);
  }
}

/**
 * Load this thread's share of the recurrent weight into registers and
 * rearrange it inside the warp with transpose_32x32().
 *
 * The weight is addressed with a row stride of valueSize.
 *   - valueSize == 128: the thread reads frameSize values of column `index`
 *     and the warp rearranges them as one 32-wide tile.
 *   - valueSize == 256: two 32-row tiles are read per thread, 32 columns
 *     apart; odd warps start 32 rows further down and 32 columns to the
 *     left.  Each tile is rearranged separately.
 * For any other valueSize nothing is loaded and rWeight is left untouched.
 *
 * @param rWeight  per-thread destination (frameSize entries for 128,
 *                 64 entries for 256).
 * @param weight   base pointer of the weight.
 * @param index    thread index inside the block.
 */
template<int valueSize, int frameSize>
__device__ void load_weight(real rWeight[], real *weight, const int index) {
  const int lane = index % 32;

  if (valueSize == 128) {
    const real *column = weight + index;
    #pragma unroll
    for (int row = 0; row < frameSize; ++row) {
      rWeight[row] = column[row * valueSize];
    }
    transpose_32x32(rWeight, lane);
  }

  if (valueSize == 256) {
    // 0 for even warps, 1 for odd warps.
    const int warpParity = (index / 32) % 2;
    const int offset = index - warpParity * 32 + warpParity * 32 * valueSize;
    const real *tile = weight + offset;
    #pragma unroll
    for (int row = 0; row < 32; ++row) {
      rWeight[row] = tile[row * valueSize];
      rWeight[row + 32] = tile[row * valueSize + 32];
    }
    transpose_32x32(rWeight, lane);
    transpose_32x32(rWeight + 32, lane);
  }
}

/**
 * Backward LSTM over whole sequences; block blockIdx.x processes the frames
 * [starts[blockIdx.x], starts[blockIdx.x + 1]), walking them in the
 * direction opposite to `reversed` (all cursors use <!reversed>).
 *
 * Launch layout (see hl_lstm_parallel_backward_data): a 1-D block of
 * valueSize threads.  frameIdy = index / frameSize selects the role of a
 * thread (1 / 2 / 3 use checkIg / checkFg / checkOg), frameIdx =
 * index % frameSize its element.
 *
 * Per step: backward_sequence() produces the gate gradient, which is written
 * to gateGrad; unless this is the last frame, the gate gradients are then
 * multiplied with the register-resident weights (B_r) and the result is
 * accumulated into the output gradient used by the next step.  After the
 * loop the accumulated check gradients are added atomically to
 * checkIgGrad / checkFgGrad / checkOgGrad (each may be null).
 *
 * stateGrad and preOutputGrad are not referenced in the kernel body.
 *
 * NOTE(review): only valueSize 128 and 256 have a weight product below.  For
 * other instantiations (the host wrapper also launches 512 and 1024)
 * shOutputGrad is summed at the end of each step without being written by
 * this kernel - confirm whether those sizes are meant to be supported.
 * NOTE(review): in the valueSize == 128 path shGateGrad is written and then
 * read by the other lanes of the same warp without an explicit barrier -
 * confirm this is safe on the targeted architectures.
 */
template<int valueSize, int frameSize, int reversed>
__global__ void KeLstmBackward(real *gateValue,
                               real *gateGrad,
                               real *stateValue,
                               real *stateGrad,       /* do not need save */
                               real *preOutputValue,
                               real *preOutputGrad,   /* do not need save */
                               real *checkIg,
                               real *checkIgGrad,
                               real *checkFg,
                               real *checkFgGrad,
                               real *checkOg,
                               real *checkOgGrad,
                               real *outputGrad,
                               real *weightValue,
                               const int *starts,
                               hl_activation_mode_t active_node,
                               hl_activation_mode_t active_gate,
                               hl_activation_mode_t active_state) {
  __shared__ real shGateValue[valueSize];
  __shared__ real shStateGrad[valueSize];
  __shared__ real shStateValue[frameSize];
  __shared__ real shGateGrad[4][frameSize];
  __shared__ real shOutputGrad[4][frameSize];
  const int index = threadIdx.x;
  int start = starts[blockIdx.x];
  int length = starts[blockIdx.x + 1] - start;

  const int frameIdx = index % frameSize;
  const int frameIdy = index / frameSize;
  real rCheck;
  real rCheckGrad;
  real rGateGrad;
  real rStateGrad;
  real rGateValuePrev;
  real rPreOutputValue;
  real rOutputGrad;
  real rGateValue;
  real rStateValue;

  frameValue frameGateValue(gateValue);
  frameValue frameGateGrad(gateGrad);
  frameValue framePreOutputValue(preOutputValue);
  frameValue frameStateValue(stateValue);
  frameValue frameOutputGrad(outputGrad);
  if (frameIdy == 0) {
  } else if (frameIdy == 1) {
    rCheck = checkIg[frameIdx];
  } else if (frameIdy == 2) {
    rCheck = checkFg[frameIdx];
    rGateValuePrev = 0.0;
    rStateGrad = 0.0;
  } else if (frameIdy == 3) {
    rCheck = checkOg[frameIdx];
    framePreOutputValue.init<!reversed, frameSize>(start, length, frameIdx);
    frameOutputGrad.init<!reversed, frameSize>(start, length, frameIdx);
    rOutputGrad = frameOutputGrad.getValue();
    rPreOutputValue = framePreOutputValue.getValue();
    frameStateValue.init<!reversed, frameSize>(start, length, frameIdx);
    rStateValue = frameStateValue.getValue();
  }

  frameGateValue.init<!reversed, valueSize>(start, length, index);
  frameGateGrad.init<!reversed, valueSize>(start, length, index);
  rGateValue = frameGateValue.getValue();
  rGateGrad = 0.0;
  rCheckGrad = 0.0;

  // Weights stay in registers for the whole sequence.
  real B_r[frameSize];
  load_weight<valueSize, frameSize>(B_r, weightValue, index);

  for (int i = 0; i < length; ++i) {
    if (frameIdy == 3) {
      // Publish the state value of the next frame in walking order; there
      // is none after the last frame, so 0 is used instead.
      if (i != length -1) {
        frameStateValue.nextFrame<!reversed, frameSize>();
        shStateValue[frameIdx] = frameStateValue.getValue();
      } else {
        shStateValue[frameIdx] = 0.0;
      }
    }
    backward_sequence<valueSize, frameSize>(
        rGateValue, rOutputGrad, rPreOutputValue, rGateGrad,
        rStateGrad, shStateGrad, shStateValue, shGateValue,
        rCheck, rGateValuePrev, index,
        hppl::gpu::backward[active_node],
        hppl::gpu::backward[active_gate],
        hppl::gpu::backward[active_state]);
    if (frameIdy == 3) {
      rCheckGrad += rGateGrad * rStateValue;
      rStateValue = shStateValue[frameIdx];
    }

    frameGateGrad.setValue(rGateGrad);
    frameGateGrad.nextFrame<!reversed, valueSize>();

    if (i != length - 1) {
      if (frameIdy == 3) {
        framePreOutputValue.nextFrame<!reversed, frameSize>();
        rPreOutputValue = framePreOutputValue.getValue();
        frameOutputGrad.nextFrame<!reversed, frameSize>();
        rOutputGrad = frameOutputGrad.getValue();
      } else if (frameIdy == 2) {
        rCheckGrad += rGateGrad * shStateValue[frameIdx];
      } else if (frameIdy == 1) {
        rCheckGrad += rGateGrad * shStateValue[frameIdx];
      }

      frameGateValue.nextFrame<!reversed, valueSize>();
      rGateValue = frameGateValue.getValue();
      shGateGrad[frameIdy][frameIdx] = rGateGrad;
      if (valueSize == 128) {
        real sum = 0.0f;
        #pragma unroll
        for (int n = 0; n < frameSize; n++) {
          sum += shGateGrad[frameIdy][n]*B_r[n];
        }
        if (frameIdy == 3) {
          rOutputGrad += sum;
        } else {
          shOutputGrad[frameIdy][frameIdx] = sum;
        }
      }
      if (valueSize == 256) {
        // Wait until every thread has published its gate gradient.
        ptx_sync(5, valueSize);
        real A_r[frameSize];
        for (int n = 0; n < frameSize; n++) {
          A_r[n] = shGateGrad[frameIdy][n];
        }
        real sum = 0.0f;
        for (int n = 0; n < frameSize; n++) {
          sum += A_r[n]*B_r[n];
        }
        if (frameIdy == 3) {
          rOutputGrad += sum;
        } else {
          shOutputGrad[frameIdy][frameIdx] = sum;
        }
      }

      if (frameIdy == 3) {
        // Collect the partial products of the three other groups.
        ptx_sync(6, valueSize);
        #pragma unroll
        for (int k = 0; k < 3; k++) {
          rOutputGrad += shOutputGrad[k][frameIdx];
        }
      } else {
        ptx_arrive(6, valueSize);
      }
    }
  }

  /* TODO: Temporary save & merger in another kernel */
  if (frameIdy == 1) {
    if (checkIgGrad) paddle::paddleAtomicAdd(checkIgGrad+frameIdx, rCheckGrad);
  } else if (frameIdy == 2) {
    if (checkFgGrad) paddle::paddleAtomicAdd(checkFgGrad+frameIdx, rCheckGrad);
  } else if (frameIdy == 3) {
    if (checkOgGrad) paddle::paddleAtomicAdd(checkOgGrad+frameIdx, rCheckGrad);
  }
}

/**
 * Launch the sequence-parallel backward LSTM kernel on STREAM_DEFAULT.
 *
 * One block of 4 * frameSize threads is launched per sequence.  The kernel
 * instantiation is chosen from frameSize (32, 64, 128 or 256, enforced by
 * CHECK) and the `reversed` flag.  All buffer and activation arguments are
 * forwarded to KeLstmBackward; note that the kernel takes the check
 * value/gradient pairs before outputGrad.
 *
 * @param sequence      frame offsets of the sequences, passed as `starts`.
 * @param numSequences  number of blocks to launch.
 */
void hl_lstm_parallel_backward_data(real *gateValue,
                                    real *gateGrad,
                                    real *stateValue,
                                    real *stateGrad,
                                    real *preOutputValue,
                                    real *preOutputGrad,
                                    real *outputGrad,
                                    real *checkIg,
                                    real *checkIgGrad,
                                    real *checkFg,
                                    real *checkFgGrad,
                                    real *checkOg,
                                    real *checkOgGrad,
                                    real *weight,
                                    const int *sequence,
                                    int frameSize,
                                    int numSequences,
                                    bool reversed,
                                    hl_activation_mode_t active_node,
                                    hl_activation_mode_t active_gate,
                                    hl_activation_mode_t active_state) {
  CHECK(frameSize == 32 || frameSize == 64 ||
        frameSize == 128 || frameSize == 256);
  dim3 grid(numSequences, 1);

// Launch one KeLstmBackward instantiation; the block size equals VALUE_SIZE.
#define HL_LSTM_LAUNCH_BACKWARD(VALUE_SIZE, FRAME_SIZE, REVERSED)             \
  KeLstmBackward<VALUE_SIZE, FRAME_SIZE, REVERSED>                            \
      <<<grid, VALUE_SIZE, 0, STREAM_DEFAULT>>>(                              \
          gateValue, gateGrad, stateValue, stateGrad, preOutputValue,         \
          preOutputGrad, checkIg, checkIgGrad, checkFg, checkFgGrad, checkOg, \
          checkOgGrad, outputGrad, weight, sequence,                          \
          active_node, active_gate, active_state)

  switch (frameSize) {
    case 32:
      if (reversed) {
        HL_LSTM_LAUNCH_BACKWARD(128, 32, 1);
      } else {
        HL_LSTM_LAUNCH_BACKWARD(128, 32, 0);
      }
      break;
    case 64:
      if (reversed) {
        HL_LSTM_LAUNCH_BACKWARD(256, 64, 1);
      } else {
        HL_LSTM_LAUNCH_BACKWARD(256, 64, 0);
      }
      break;
    case 128:
      if (reversed) {
        HL_LSTM_LAUNCH_BACKWARD(512, 128, 1);
      } else {
        HL_LSTM_LAUNCH_BACKWARD(512, 128, 0);
      }
      break;
    case 256:
      if (reversed) {
        HL_LSTM_LAUNCH_BACKWARD(1024, 256, 1);
      } else {
        HL_LSTM_LAUNCH_BACKWARD(1024, 256, 0);
      }
      break;
    default:
      // No kernel exists for other frame sizes.
      break;
  }

#undef HL_LSTM_LAUNCH_BACKWARD

  CHECK_SYNC("hl_lstm_parallel_backward_data");
}

/**
 * Zero one row of gateGrad per sequence.
 *
 * Expected launch: 2-D blocks of B_X x B_Y threads; x covers the valueSize
 * columns of a row, y covers the sequences.  Threads outside either range do
 * nothing.
 *
 * @param gateGrad      gradient buffer with valueSize values per row.
 * @param starts        row offsets of the sequences.
 * @param valueSize     number of values per row.
 * @param numSequences  number of sequences.
 * @param reversed      false: clear row starts[s] of each sequence s;
 *                      true: clear row starts[s + 1] - 1.
 */
template<int B_X, int B_Y>
__global__ void KeSetGradZero(real *gateGrad,
    const int *starts, int valueSize, int numSequences, bool reversed) {
  const int column = blockIdx.x * B_X + threadIdx.x;
  const int sequenceId = blockIdx.y * B_Y + threadIdx.y;

  if (sequenceId < numSequences && column < valueSize) {
    const int row = reversed ? starts[sequenceId + 1] - 1
                             : starts[sequenceId];
    gateGrad[row * valueSize + column] = 0.0;
  }
}

/**
 * Accumulate the recurrent weight gradient.
 *
 * First clears one gateGrad row per sequence with KeSetGradZero (the row at
 * starts[s], or at starts[s + 1] - 1 when `reversed`), then issues a single
 * hl_matrix_mul with both scalar factors set to 1.0, the first operand
 * flagged HPPL_OP_T and the second HPPL_OP_N, and dimensions
 * (frameSize, valueSize, batchSize - 1).  The two operands are offset
 * against each other by one row:
 *   - not reversed: outputValue with gateGrad advanced by one row;
 *   - reversed:     outputValue advanced by one row with gateGrad.
 *
 * NOTE(review): presumably this computes weightGrad += output^T * gateGrad
 * over batchSize - 1 rows - confirm against the hl_matrix_mul declaration.
 */
void hl_lstm_parallel_backward_weight(real *weightGrad,
                                      real *outputValue,
                                      real *gateGrad,
                                      const int *sequence,
                                      int frameSize,
                                      int batchSize,
                                      int numSequences,
                                      bool reversed) {
  const int valueSize = 4 * frameSize;

  // 32x32 thread tiles covering valueSize columns by numSequences rows.
  const dim3 threads(32, 32);
  const dim3 grid((valueSize + 31) / 32, (numSequences + 31) / 32);
  KeSetGradZero<32, 32><<<grid, threads, 0, STREAM_DEFAULT>>>
           (gateGrad, sequence, valueSize, numSequences, reversed);

  // Choose which operand is shifted by one row.
  real *outputOperand = reversed ? outputValue + frameSize : outputValue;
  real *gradOperand = reversed ? gateGrad : gateGrad + valueSize;

  hl_matrix_mul(outputOperand,
    HPPL_OP_T, gradOperand, HPPL_OP_N, weightGrad,
    frameSize, valueSize, batchSize - 1,
    1.0, 1.0);
  CHECK_SYNC("hl_lstm_parallel_backward_weight");
}