lstm_gpu_kernel.h 8.7 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16
#include "paddle/operators/math/detail/activation_functions.h"
D
dangqingqing 已提交
17 18
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
19
#include "paddle/platform/device_context.h"
D
dangqingqing 已提交
20

21
#include <type_traits>
22

D
dangqingqing 已提交
23 24 25 26 27 28 29 30 31 32
namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * Element-wise LSTM forward step: one thread per (frame, sample) element.
 *
 * Expected launch layout (see gpu_lstm_forward):
 *   threads(framePerBlock, batchPerBlock)
 *   grid(frameBlocks, batchBlocks)
 * The x dimension indexes the frame (hidden unit); when isBatch is true the
 * y dimension indexes the sample. Threads past frameSize / batchSize exit
 * early, so the grid may be rounded up. No shared memory is used.
 *
 * Per-sample layout implied by the indexing below:
 *   gateValue                     4 * frameSize values, as [in | ig | fg | og]
 *   outputValue, stateValue,
 *   stateActiveValue,
 *   prevStateValue                frameSize values each
 *   checkIg, checkFg, checkOg     frameSize values each, shared by all samples
 *
 * `op` is handed the four gate values, the previous state (0 when
 * value.prevStateValue is null) and the three peephole weights, and produces
 * the new state, its activated value and the output. The gate registers are
 * stored back into gateValue after `op` returns, so any in-place update `op`
 * makes to them is persisted.
 */
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
                              int batchSize, activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  // Element offset of this sample's frameSize-wide rows (0 when not batched).
  // Computed in size_t so batchIdx * frameSize (* 4) cannot overflow int.
  size_t offset = 0;
  if (isBatch) {
    const int batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    offset = static_cast<size_t>(batchIdx) * frameSize;
    value.gateValue += offset * 4;
    value.outputValue += offset;
    value.stateValue += offset;
    value.stateActiveValue += offset;
  }

  // Kept non-const: `op` may take its arguments by non-const reference.
  T rState;
  T rPrevState = 0;
  T rStateAtv;
  T rOut;
  T rValueIn = value.gateValue[frameIdx];
  T rValueIg = value.gateValue[frameIdx + frameSize];
  T rValueFg = value.gateValue[frameIdx + frameSize * 2];
  T rValueOg = value.gateValue[frameIdx + frameSize * 3];
  T rCheckI = value.checkIg[frameIdx];
  T rCheckF = value.checkFg[frameIdx];
  T rCheckO = value.checkOg[frameIdx];

  // With no previous state (presumably the first time step) the state input
  // stays 0.
  if (value.prevStateValue) {
    rPrevState = value.prevStateValue[offset + frameIdx];
  }

  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
     rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);

  value.gateValue[frameIdx] = rValueIn;
  value.gateValue[frameIdx + frameSize] = rValueIg;
  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
  value.gateValue[frameIdx + frameSize * 3] = rValueOg;

  value.stateValue[frameIdx] = rState;
  value.stateActiveValue[frameIdx] = rStateAtv;
  value.outputValue[frameIdx] = rOut;
}

/*
 * Element-wise LSTM backward step: one thread per (frame, sample) element.
 *
 * Expected launch layout (see gpu_lstm_backward):
 *   threads(framePerBlock, batchPerBlock)
 *   grid(frameBlocks, batchBlocks)
 * The x dimension indexes the frame (hidden unit); when isBatch is true the
 * y dimension indexes the sample. Threads past frameSize / batchSize exit
 * early, so the grid may be rounded up. No shared memory is used.
 *
 * Per-sample layout implied by the indexing below:
 *   gateValue, gateGrad           4 * frameSize values, as [in | ig | fg | og]
 *   stateValue, stateActiveValue,
 *   prevStateValue, stateGrad,
 *   outputGrad, prevStateGrad     frameSize values each
 *   checkIg/Fg/Og (+ their grads) frameSize values each, shared by all samples
 *
 * Writes gateGrad, stateGrad and (if non-null) prevStateGrad for this
 * element. Peephole gradients (if non-null) are accumulated into, not
 * overwritten.
 */
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                               LstmMetaGrad<T> grad, int frameSize,
                               int batchSize, activation_mode_t active_node,
                               activation_mode_t active_gate,
                               activation_mode_t active_state) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  // Element offset of this sample's frameSize-wide rows (0 when not batched).
  // Computed in size_t so batchIdx * frameSize (* 4) cannot overflow int.
  size_t offset = 0;
  if (isBatch) {
    const int batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    offset = static_cast<size_t>(batchIdx) * frameSize;
    value.gateValue += offset * 4;
    value.stateValue += offset;
    value.stateActiveValue += offset;
    grad.gateGrad += offset * 4;
    grad.stateGrad += offset;
    grad.outputGrad += offset;
  }

  // Kept non-const: `op` may take its arguments by non-const reference.
  T rValueIn = value.gateValue[frameIdx];
  T rValueIg = value.gateValue[frameIdx + frameSize];
  T rValueFg = value.gateValue[frameIdx + frameSize * 2];
  T rValueOg = value.gateValue[frameIdx + frameSize * 3];
  T rGradIn;
  T rGradIg;
  T rGradFg;
  T rGradOg;
  T rPrevState = 0;
  T rPrevStateGrad;
  T rState = value.stateValue[frameIdx];
  T rStateGrad = grad.stateGrad[frameIdx];
  T rStateAtv = value.stateActiveValue[frameIdx];
  T rOutputGrad = grad.outputGrad[frameIdx];
  T rCheckI = value.checkIg[frameIdx];
  T rCheckF = value.checkFg[frameIdx];
  T rCheckO = value.checkOg[frameIdx];
  T rCheckIGrad;
  T rCheckFGrad;
  T rCheckOGrad;

  // With no previous state the state input stays 0.
  if (value.prevStateValue) {
    rPrevState = value.prevStateValue[offset + frameIdx];
  }

  op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
     rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
     rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
     active_node, active_gate, active_state);

  grad.gateGrad[frameIdx] = rGradIn;
  grad.gateGrad[frameIdx + frameSize] = rGradIg;
  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
  grad.stateGrad[frameIdx] = rStateGrad;
  if (grad.prevStateGrad) {
    grad.prevStateGrad[offset + frameIdx] = rPrevStateGrad;
  }

  // The input/forget peephole gradients are only accumulated when a previous
  // state exists (presumably because those peepholes read the previous state,
  // so without it their contribution is zero).
  if (isBatch) {
    // Every sample in the batch maps onto the same per-frame peephole slot,
    // so concurrent threads must accumulate atomically.
    if (value.prevStateValue) {
      if (grad.checkIgGrad)
        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
                                        rCheckIGrad);
      if (grad.checkFgGrad)
        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
                                        rCheckFGrad);
    }
    if (grad.checkOgGrad)
      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
  } else {
    // Single sample: exactly one thread owns each frame slot.
    if (value.prevStateValue) {
      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
    }
    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
  }
}

/*
 * Launches KeLstmForward on the CUDA stream owned by `context`.
 *
 * batchSize == 1 uses a 1-D launch (up to 1024 frames per block) with the
 * non-batched kernel; otherwise a 2-D launch of 32x32 (frame x batch)
 * threads with the batched kernel. The launch is asynchronous: only
 * launch-time errors are reported here (via PADDLE_ENFORCE). Errors raised
 * while the kernel runs surface at the next synchronizing call.
 *
 * An empty problem (frameSize or batchSize <= 0) is a no-op. `context` must
 * actually refer to a platform::CUDADeviceContext.
 */
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                      LstmMetaValue<T> value, int frameSize, int batchSize,
                      activation_mode_t active_node,
                      activation_mode_t active_gate,
                      activation_mode_t active_state) {
  // A zero-sized block or grid is an invalid launch configuration, and there
  // is nothing to compute anyway.
  if (frameSize <= 0 || batchSize <= 0) return;

  // static_cast is the proper downcast within the DeviceContext hierarchy;
  // unlike reinterpret_cast it applies any base-offset adjustment and fails
  // to compile if the types are unrelated.
  auto stream =
      static_cast<const platform::CUDADeviceContext&>(context).stream();

  if (batchSize == 1) {
    const int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    const int frameBlocks = (frameSize + 1024 - 1) / 1024;
    dim3 threads(framePerBlock, 1);
    dim3 grid(frameBlocks, 1);
    KeLstmForward<T, Op,
                  /* isBatch= */ false><<<grid, threads, 0, stream>>>(
        op, value, frameSize, batchSize, active_node, active_gate,
        active_state);
  } else {
    /* framePerBlock = 32 batchPerBlock = 32 */
    dim3 threads(32, 32);
    dim3 grid((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
    KeLstmForward<T, Op,
                  /* isBatch= */ true><<<grid, threads, 0, stream>>>(
        op, value, frameSize, batchSize, active_node, active_gate,
        active_state);
  }

  // Kernel launches return no status. Query it explicitly so a rejected
  // launch (e.g. too many resources for a 1024-thread block) is reported here
  // instead of being silently dropped and picked up later by an unrelated
  // cudaGetLastError call.
  cudaError_t err = cudaGetLastError();
  PADDLE_ENFORCE(err, cudaGetErrorString(err));
}

/*
 * Launches KeLstmBackward on the CUDA stream owned by `context` and waits for
 * it to finish.
 *
 * batchSize == 1 uses a 1-D launch (up to 1024 frames per block) with the
 * non-batched kernel; otherwise a 2-D launch of 32x16 (frame x batch)
 * threads with the batched kernel. Both launch-time errors and errors raised
 * while the kernel runs are reported here via PADDLE_ENFORCE, because the
 * stream is synchronized before returning.
 *
 * An empty problem (frameSize or batchSize <= 0) is a no-op. `context` must
 * actually refer to a platform::CUDADeviceContext.
 */
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frameSize, int batchSize,
                       activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
  // A zero-sized block or grid is an invalid launch configuration, and there
  // is nothing to compute anyway.
  if (frameSize <= 0 || batchSize <= 0) return;

  // static_cast is the proper downcast within the DeviceContext hierarchy;
  // unlike reinterpret_cast it applies any base-offset adjustment and fails
  // to compile if the types are unrelated.
  auto stream =
      static_cast<const platform::CUDADeviceContext&>(context).stream();

  if (batchSize == 1) {
    const int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    const int frameBlocks = (frameSize + 1024 - 1) / 1024;
    dim3 threads(framePerBlock, 1);
    dim3 grid(frameBlocks, 1);
    KeLstmBackward<T, Op,
                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
        op, value, grad, frameSize, batchSize, active_node, active_gate,
        active_state);
  } else {
    /* framePerBlock = 32 batchPerBlock = 16 */
    dim3 threads(32, 16);
    dim3 grid((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16);
    KeLstmBackward<T, Op,
                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
        op, value, grad, frameSize, batchSize, active_node, active_gate,
        active_state);
  }

  // TODO(qingqing): Add cuda error check for each kernel.
  // Launch-time errors (invalid configuration, resource limits).
  cudaError_t err = cudaGetLastError();
  PADDLE_ENFORCE(err, cudaGetErrorString(err));
  // Errors raised while the kernel executes are only reported by a
  // synchronizing call, so wait for the stream and check its status too.
  err = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE(err, cudaGetErrorString(err));
}

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle