lstm_gpu_kernel.h 8.0 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16 17
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
D
dangqingqing 已提交
18 19
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
20
#include "paddle/platform/device_context.h"
D
dangqingqing 已提交
21

22 23
#include <glog/logging.h>

D
dangqingqing 已提交
24 25 26 27 28 29 30 31 32 33
namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * Element-wise LSTM forward step.
 *
 * Expected launch layout:
 *   threads(framePerBlock, batchPerBlock)
 *   grid(frameBlocks, batchBlocks)
 *
 * One thread handles one frame index (x dimension). When isBatch is true the
 * y dimension selects a batch row, and the per-row buffers are offset by
 * frameSize elements per row (4 * frameSize for gateValue, which stores four
 * consecutive frameSize-long segments per row).
 *
 * The thread loads its four gate values, the optional previous state and the
 * optional check* values into locals, invokes `op` on them, and stores the
 * four gate values, state, active state and output back to `value`.
 *
 * Template parameters:
 *   T       - element type of all buffers.
 *   Op      - functor invoked once per element with the locals listed above.
 *   isBatch - whether the y dimension indexes batch rows.
 *
 * Parameters:
 *   op        - functor instance, passed by value.
 *   value     - buffer pointers; prevStateValue and check* may be null, in
 *               which case 0 is passed to `op` for the missing operand.
 *   frameSize - number of elements per row in each per-row buffer.
 *   batchSize - number of rows; only consulted when isBatch is true.
 */
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
                              int batchSize) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  // Element offset of this thread's batch row. Computed in 64 bits so that
  // batchIdx * frameSize * 4 cannot overflow a 32-bit int on large buffers.
  long long rowOffset = 0;
  if (isBatch) {
    const int batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    rowOffset = static_cast<long long>(batchIdx) * frameSize;
    value.gateValue += rowOffset * 4;
    value.outputValue += rowOffset;
    value.stateValue += rowOffset;
    value.stateActiveValue += rowOffset;
  }

  // The check* buffers are indexed by frame only (no batch offset). They are
  // treated as optional: a null pointer contributes 0 instead of being
  // dereferenced.
  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : static_cast<T>(0);
  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : static_cast<T>(0);
  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : static_cast<T>(0);

  // Four gate segments of this row, each frameSize elements long.
  T rValueIn = value.gateValue[frameIdx];
  T rValueIg = value.gateValue[frameIdx + frameSize];
  T rValueFg = value.gateValue[frameIdx + frameSize * 2];
  T rValueOg = value.gateValue[frameIdx + frameSize * 3];

  // No previous state buffer: op sees a previous state of 0.
  T rPrevState = 0;
  if (value.prevStateValue) {
    rPrevState = value.prevStateValue[rowOffset + frameIdx];
  }

  // Outputs produced by op.
  T rState;
  T rStateAtv;
  T rOut;

  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
     rOut, rCheckI, rCheckF, rCheckO);

  // Gate values are written back in place after op has processed them.
  value.gateValue[frameIdx] = rValueIn;
  value.gateValue[frameIdx + frameSize] = rValueIg;
  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
  value.gateValue[frameIdx + frameSize * 3] = rValueOg;

  value.stateValue[frameIdx] = rState;
  value.stateActiveValue[frameIdx] = rStateAtv;
  value.outputValue[frameIdx] = rOut;
}

/*
 * Element-wise LSTM backward step.
 *
 * Expected launch layout:
 *   threads(framePerBlock, batchPerBlock)
 *   grid(frameBlocks, batchBlocks)
 *
 * One thread handles one frame index (x dimension). When isBatch is true the
 * y dimension selects a batch row, and the per-row buffers are offset by
 * frameSize elements per row (4 * frameSize for gateValue / gateGrad).
 *
 * The thread loads the forward values and incoming gradients for its element,
 * invokes `op`, then stores the gate gradients, the updated state gradient
 * and (if the buffer exists) the previous-state gradient. Gradients for the
 * check* values are accumulated into buffers indexed by frame only.
 *
 * Template parameters:
 *   T       - element type of all buffers.
 *   Op      - functor invoked once per element.
 *   isBatch - whether the y dimension indexes batch rows.
 *
 * Parameters:
 *   op        - functor instance, passed by value.
 *   value     - forward buffers; prevStateValue and check* may be null, in
 *               which case 0 is passed to `op` for the missing operand.
 *   grad      - gradient buffers; prevStateGrad and check*Grad may be null
 *               and are skipped when they are.
 *   frameSize - number of elements per row in each per-row buffer.
 *   batchSize - number of rows; only consulted when isBatch is true.
 */
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                               LstmMetaGrad<T> grad, int frameSize,
                               int batchSize) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  // Element offset of this thread's batch row. Computed in 64 bits so that
  // batchIdx * frameSize * 4 cannot overflow a 32-bit int on large buffers.
  long long rowOffset = 0;
  if (isBatch) {
    const int batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    rowOffset = static_cast<long long>(batchIdx) * frameSize;
    value.gateValue += rowOffset * 4;
    value.stateValue += rowOffset;
    value.stateActiveValue += rowOffset;
    grad.gateGrad += rowOffset * 4;
    grad.stateGrad += rowOffset;
    grad.outputGrad += rowOffset;
  }

  // The check* buffers are indexed by frame only (no batch offset). They are
  // treated as optional: a null pointer contributes 0 instead of being
  // dereferenced.
  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : static_cast<T>(0);
  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : static_cast<T>(0);
  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : static_cast<T>(0);

  // Forward values and incoming gradients for this element.
  T rValueIn = value.gateValue[frameIdx];
  T rValueIg = value.gateValue[frameIdx + frameSize];
  T rValueFg = value.gateValue[frameIdx + frameSize * 2];
  T rValueOg = value.gateValue[frameIdx + frameSize * 3];
  T rState = value.stateValue[frameIdx];
  T rStateAtv = value.stateActiveValue[frameIdx];
  T rOutputGrad = grad.outputGrad[frameIdx];
  T rStateGrad = grad.stateGrad[frameIdx];

  // No previous state buffer: op sees a previous state of 0.
  T rPrevState = 0;
  if (value.prevStateValue) {
    rPrevState = value.prevStateValue[rowOffset + frameIdx];
  }

  // Outputs produced by op.
  T rGradIn;
  T rGradIg;
  T rGradFg;
  T rGradOg;
  T rPrevStateGrad;
  T rCheckIGrad;
  T rCheckFGrad;
  T rCheckOGrad;

  op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
     rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);

  grad.gateGrad[frameIdx] = rGradIn;
  grad.gateGrad[frameIdx + frameSize] = rGradIg;
  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
  grad.stateGrad[frameIdx] = rStateGrad;
  if (grad.prevStateGrad) {
    grad.prevStateGrad[rowOffset + frameIdx] = rPrevStateGrad;
  }

  // The checkIg / checkFg gradients are only accumulated when a previous
  // state exists; the checkOg gradient is accumulated unconditionally.
  if (isBatch) {
    // Threads from different batch rows share the same frameIdx slot in the
    // check*Grad buffers, so the accumulation must be atomic.
    if (value.prevStateValue) {
      if (grad.checkIgGrad)
        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
                                        rCheckIGrad);
      if (grad.checkFgGrad)
        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
                                        rCheckFGrad);
    }
    if (grad.checkOgGrad)
      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
  } else {
    // Single row: each frameIdx slot is touched by exactly one thread.
    if (value.prevStateValue) {
      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
    }
    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
  }
}

/*
 * Host-side launcher for KeLstmForward.
 *
 * Picks the launch geometry from batchSize:
 *   - batchSize == 1: 1-D launch, up to 1024 threads per block along the
 *     frame dimension, kernel instantiated with isBatch = false.
 *   - otherwise:      32 x 32 thread blocks tiling (frame, batch), kernel
 *     instantiated with isBatch = true.
 * The kernel is enqueued on the stream obtained from `context`; this function
 * does not synchronize.
 *
 * Parameters:
 *   context      - device context; reinterpreted as CUDADeviceContext to
 *                  obtain the stream.
 *   op           - functor forwarded to the kernel.
 *   value        - buffer pointers forwarded to the kernel.
 *   frameSize    - elements per row.
 *   batchSize    - number of rows.
 *   active_node, active_gate, active_state
 *                - not referenced by this function; kept for interface
 *                  compatibility with callers.
 */
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                      LstmMetaValue<T> value, int frameSize, int batchSize,
                      activation_mode_t active_node,
                      activation_mode_t active_gate,
                      activation_mode_t active_state) {
  // Nothing to compute. A zero-sized block or grid dimension would also be
  // rejected as an invalid launch configuration, so do not attempt it.
  if (frameSize <= 0 || batchSize <= 0) return;

  const int kMaxThreadsPerBlock = 1024;
  const int kTile = 32; /* framePerBlock = 32 batchPerBlock = 32 */

  auto stream =
      reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();

  if (batchSize == 1) {
    const int framePerBlock =
        frameSize <= kMaxThreadsPerBlock ? frameSize : kMaxThreadsPerBlock;
    const int frameBlocks =
        (frameSize + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
    const dim3 threads(framePerBlock, 1);
    const dim3 grid(frameBlocks, 1);
    KeLstmForward<T, Op,
                  /* isBatch= */ false><<<grid, threads, 0, stream>>>(
        op, value, frameSize, batchSize);
  } else {
    const dim3 threads(kTile, kTile);
    const dim3 grid((frameSize + kTile - 1) / kTile,
                    (batchSize + kTile - 1) / kTile);
    KeLstmForward<T, Op,
                  /* isBatch= */ true><<<grid, threads, 0, stream>>>(
        op, value, frameSize, batchSize);
  }
}

/*
 * Host-side launcher for KeLstmBackward.
 *
 * Picks the launch geometry from batchSize:
 *   - batchSize == 1: 1-D launch, up to 1024 threads per block along the
 *     frame dimension, kernel instantiated with isBatch = false.
 *   - otherwise:      32 x 32 thread blocks tiling (frame, batch), kernel
 *     instantiated with isBatch = true.
 * The kernel is enqueued on the stream obtained from `context`; this function
 * does not synchronize.
 *
 * Parameters:
 *   context      - device context; reinterpreted as CUDADeviceContext to
 *                  obtain the stream.
 *   op           - functor forwarded to the kernel.
 *   value        - forward buffer pointers forwarded to the kernel.
 *   grad         - gradient buffer pointers forwarded to the kernel.
 *   frameSize    - elements per row.
 *   batchSize    - number of rows.
 *   active_node, active_gate, active_state
 *                - not referenced by this function; kept for interface
 *                  compatibility with callers.
 */
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frameSize, int batchSize,
                       activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
  // Nothing to compute. A zero-sized block or grid dimension would also be
  // rejected as an invalid launch configuration, so do not attempt it.
  if (frameSize <= 0 || batchSize <= 0) return;

  const int kMaxThreadsPerBlock = 1024;
  const int kTile = 32; /* framePerBlock = 32 batchPerBlock = 32 */

  auto stream =
      reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();

  if (batchSize == 1) {
    const int framePerBlock =
        frameSize <= kMaxThreadsPerBlock ? frameSize : kMaxThreadsPerBlock;
    const int frameBlocks =
        (frameSize + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
    const dim3 threads(framePerBlock, 1);
    const dim3 grid(frameBlocks, 1);
    KeLstmBackward<T, Op,
                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
        op, value, grad, frameSize, batchSize);
  } else {
    const dim3 threads(kTile, kTile);
    const dim3 grid((frameSize + kTile - 1) / kTile,
                    (batchSize + kTile - 1) / kTile);
    KeLstmBackward<T, Op,
                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
        op, value, grad, frameSize, batchSize);
  }
}

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle