/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#ifndef HL_GPU_LSTM_CUH_
#define HL_GPU_LSTM_CUH_

#ifdef __NVCC__

#include "paddle/legacy/utils/Logging.h"
#include "hl_device_functions.cuh"

/*
 * Forward kernel for one LSTM step; each thread handles one (frame, batch)
 * cell.
 *
 * Launch layout (as set up by hl_gpu_lstm_forward):
 *   threads(framePerBlock, batchPerBlock)
 *   grid(frameBlocks, batchBlocks)
 * The x dimension indexes frames; the y dimension indexes batch rows and is
 * only consulted when isBatch is true (otherwise batchSize is ignored).
 *
 * Per cell, the four gate values are read from gateValue at strides of
 * frameSize (slots named In/Ig/Fg/Og), together with the check weights
 * checkIg/checkFg/checkOg[frameIdx] and the previous state (taken as 0 when
 * prevStateValue is null). They are passed to `op`, and the (possibly
 * updated) gates, the new state, the activated state and the output are
 * written back.
 */
template<class Op, bool isBatch>
__global__ void KeLstmForward(Op op,
                              hl_lstm_value value,
                              int frameSize,
                              int batchSize,
                              hl_activation_mode_t active_node,
                              hl_activation_mode_t active_gate,
                              hl_activation_mode_t active_state) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    // Shift every per-row pointer to this thread's batch row; the gate
    // buffer holds four values per frame, hence the extra factor of 4.
    const int rowOffset = batchIdx * frameSize;
    value.gateValue += rowOffset * 4;
    value.outputValue += rowOffset;
    value.stateValue += rowOffset;
    value.stateActiveValue += rowOffset;
  }

  // The four gate slots for this frame, frameSize elements apart.
  real *gates = value.gateValue + frameIdx;
  real valIn = gates[0];
  real valIg = gates[frameSize];
  real valFg = gates[2 * frameSize];
  real valOg = gates[3 * frameSize];

  real peepI = value.checkIg[frameIdx];
  real peepF = value.checkFg[frameIdx];
  real peepO = value.checkOg[frameIdx];

  // A null prevStateValue means there is no previous state: use 0.
  real prevState = 0;
  if (value.prevStateValue) {
    const real *prev = value.prevStateValue;
    if (isBatch) prev += batchIdx * frameSize;
    prevState = prev[frameIdx];
  }

  real state;
  real stateAtv;
  real out;
  op(valIn, valIg, valFg, valOg,
     prevState, state, stateAtv, out,
     peepI, peepF, peepO,
     hppl::gpu::forward[active_node],
     hppl::gpu::forward[active_gate],
     hppl::gpu::forward[active_state]);

  gates[0] = valIn;
  gates[frameSize] = valIg;
  gates[2 * frameSize] = valFg;
  gates[3 * frameSize] = valOg;

  value.stateValue[frameIdx] = state;
  value.stateActiveValue[frameIdx] = stateAtv;
  value.outputValue[frameIdx] = out;
}

/*
 * Backward kernel for one LSTM step; each thread handles one (frame, batch)
 * cell.
 *
 * Launch layout (as set up by hl_gpu_lstm_backward):
 *   threads(framePerBlock, batchPerBlock)
 *   grid(frameBlocks, batchBlocks)
 * The x dimension indexes frames; the y dimension indexes batch rows and is
 * only consulted when isBatch is true.
 *
 * Per cell, it reads:
 * - the forward gate values (four slots, frameSize apart),
 * - the state and activated state,
 * - the incoming output and state gradients,
 * - the check weights,
 * - the previous state (0 when prevStateValue is null).
 *
 * It then calls `op` and writes back the results:
 * - gateGrad (four slots) and stateGrad are overwritten.
 * - prevStateGrad is written only when it is non-null.
 * - The check-weight gradients are accumulated (added to) when their
 *   pointers are non-null. checkIg/checkFg are accumulated only when a
 *   previous state exists.
 */
template<class Op, bool isBatch>
__global__ void KeLstmBackward(Op op,
                               hl_lstm_value value,
                               hl_lstm_grad grad,
                               int frameSize,
                               int batchSize,
                               hl_activation_mode_t active_node,
                               hl_activation_mode_t active_gate,
                               hl_activation_mode_t active_state) {
  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frameIdx >= frameSize) return;

  int batchIdx = 0;
  if (isBatch) {
    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batchIdx >= batchSize) return;
    // Shift every per-row pointer to this thread's batch row; gate buffers
    // hold four values per frame.
    value.gateValue += batchIdx * frameSize * 4;
    value.stateValue += batchIdx * frameSize;
    value.stateActiveValue += batchIdx * frameSize;
    grad.gateGrad += batchIdx * frameSize * 4;
    grad.stateGrad += batchIdx * frameSize;
    grad.outputGrad += batchIdx * frameSize;
  }

  real rValueIn;
  real rValueIg;
  real rValueFg;
  real rValueOg;
  real rGradIn;
  real rGradIg;
  real rGradFg;
  real rGradOg;
  real rPrevState = 0;
  real rPrevStateGrad;
  real rState;
  real rStateGrad;
  real rStateAtv;
  real rOutputGrad;
  real rCheckI = value.checkIg[frameIdx];
  real rCheckF = value.checkFg[frameIdx];
  real rCheckO = value.checkOg[frameIdx];
  real rCheckIGrad;
  real rCheckFGrad;
  real rCheckOGrad;

  rValueIn = value.gateValue[frameIdx];
  rValueIg = value.gateValue[frameIdx + frameSize];
  rValueFg = value.gateValue[frameIdx + frameSize * 2];
  rValueOg = value.gateValue[frameIdx + frameSize * 3];
  rState = value.stateValue[frameIdx];
  rStateAtv = value.stateActiveValue[frameIdx];
  rOutputGrad = grad.outputGrad[frameIdx];
  rStateGrad = grad.stateGrad[frameIdx];

  // A null prevStateValue means there is no previous state: use 0.
  if (value.prevStateValue) {
    if (isBatch) value.prevStateValue += batchIdx * frameSize;
    rPrevState = value.prevStateValue[frameIdx];
  }

  op(rValueIn,
     rValueIg,
     rValueFg,
     rValueOg,
     rGradIn,
     rGradIg,
     rGradFg,
     rGradOg,
     rPrevState,
     rPrevStateGrad,
     rState,
     rStateGrad,
     rStateAtv,
     rOutputGrad,
     rCheckI,
     rCheckF,
     rCheckO,
     rCheckIGrad,
     rCheckFGrad,
     rCheckOGrad,
     hppl::gpu::backward[active_node],
     hppl::gpu::backward[active_gate],
     hppl::gpu::backward[active_state]);

  grad.gateGrad[frameIdx] = rGradIn;
  grad.gateGrad[frameIdx + frameSize    ] = rGradIg;
  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
  grad.stateGrad[frameIdx] = rStateGrad;
  if (grad.prevStateGrad) {
    if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
    grad.prevStateGrad[frameIdx] = rPrevStateGrad;
  }

  // Check-weight gradients are indexed by frameIdx only (no batch offset),
  // so they are shared by all batch rows.
  if (isBatch) {
    // One thread per batch row updates the same entry: accumulate atomically.
    if (value.prevStateValue) {
      if (grad.checkIgGrad)
        paddle::paddleAtomicAdd(grad.checkIgGrad + frameIdx, rCheckIGrad);
      if (grad.checkFgGrad)
        paddle::paddleAtomicAdd(grad.checkFgGrad + frameIdx, rCheckFGrad);
    }
    if (grad.checkOgGrad)
      paddle::paddleAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
  } else {
    // With the 1-D launch used for isBatch == false there is exactly one
    // thread per frameIdx, so a plain += does not race.
    if (value.prevStateValue) {
      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
    }
    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
  }
}

/*
 * Host entry point: launches KeLstmForward over a batchSize x frameSize step
 * on STREAM_DEFAULT, then invokes CHECK_SYNC with a failure message.
 *
 * - batchSize == 1: 1-D launch with up to 1024 threads per block along the
 *   frame dimension, using the isBatch = false kernel.
 * - batchSize  > 1: 2-D launch with 32 x 32 blocks (frames x batch rows),
 *   using the isBatch = true kernel.
 *
 * Empty input (frameSize <= 0 or batchSize <= 0) is a no-op. Launching with a
 * zero-sized block or grid dimension is an invalid CUDA configuration.
 */
template<class Op>
void hl_gpu_lstm_forward(Op op,
                         hl_lstm_value value,
                         int frameSize,
                         int batchSize,
                         hl_activation_mode_t active_node,
                         hl_activation_mode_t active_gate,
                         hl_activation_mode_t active_state) {
  // Nothing to compute; also avoids an invalid zero-sized launch.
  if (frameSize <= 0 || batchSize <= 0) return;

  if (batchSize == 1) {
    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    int frameBlocks = (frameSize + 1024 - 1) / 1024;
    dim3 threads(framePerBlock, 1);
    dim3 grid(frameBlocks, 1);
    KeLstmForward<Op, /* isBatch= */false>
      <<<grid, threads, 0, STREAM_DEFAULT>>>(op, value,
      frameSize, batchSize, active_node, active_gate, active_state);
  } else {
    /* framePerBlock = 32 batchPerBlock = 32 */
    dim3 threads(32, 32);
    dim3 grid((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
    KeLstmForward<Op, /* isBatch= */true>
      <<<grid, threads, 0, STREAM_DEFAULT>>>(op, value,
      frameSize, batchSize, active_node, active_gate, active_state);
  }

  CHECK_SYNC("hl_gpu_lstm_forward failed");
}

/*
 * Host entry point: launches KeLstmBackward over a batchSize x frameSize
 * step on STREAM_DEFAULT, then invokes CHECK_SYNC with a failure message.
 *
 * - batchSize == 1: 1-D launch with up to 1024 threads per block along the
 *   frame dimension, using the isBatch = false kernel.
 * - batchSize  > 1: 2-D launch with 32 x 32 blocks (frames x batch rows),
 *   using the isBatch = true kernel.
 *
 * Empty input (frameSize <= 0 or batchSize <= 0) is a no-op. Launching with a
 * zero-sized block or grid dimension is an invalid CUDA configuration.
 */
template<class Op>
void hl_gpu_lstm_backward(Op op,
                          hl_lstm_value value,
                          hl_lstm_grad grad,
                          int frameSize,
                          int batchSize,
                          hl_activation_mode_t active_node,
                          hl_activation_mode_t active_gate,
                          hl_activation_mode_t active_state) {
  // Nothing to compute; also avoids an invalid zero-sized launch.
  if (frameSize <= 0 || batchSize <= 0) return;

  if (batchSize == 1) {
    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
    int frameBlocks = (frameSize + 1024 - 1) / 1024;
    dim3 threads(framePerBlock, 1);
    dim3 grid(frameBlocks, 1);
    KeLstmBackward<Op, /* isBatch= */false>
      <<<grid, threads, 0, STREAM_DEFAULT>>>(op, value, grad,
      frameSize, batchSize, active_node, active_gate, active_state);
  } else {
    /* framePerBlock = 32 batchPerBlock = 32 */
    dim3 threads(32, 32);
    dim3 grid((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
    KeLstmBackward<Op, /* isBatch= */true>
      <<<grid, threads, 0, STREAM_DEFAULT>>>(op, value, grad,
      frameSize, batchSize, active_node, active_gate, active_state);
  }

  CHECK_SYNC("hl_gpu_lstm_backward failed");
}

#else

/* Non-nvcc build: the GPU forward entry point is an empty stub so that code
 * referring to it still compiles. Every argument is ignored. */
template<class Op>
void hl_gpu_lstm_forward(Op op,
                         hl_lstm_value value,
                         int frameSize,
                         int batchSize,
                         hl_activation_mode_t active_node,
                         hl_activation_mode_t active_gate,
                         hl_activation_mode_t active_state) {
  (void)op;
  (void)value;
  (void)frameSize;
  (void)batchSize;
  (void)active_node;
  (void)active_gate;
  (void)active_state;
}

/* Non-nvcc build: the GPU backward entry point is an empty stub so that code
 * referring to it still compiles. Every argument is ignored. */
template<class Op>
void hl_gpu_lstm_backward(Op op,
                          hl_lstm_value value,
                          hl_lstm_grad grad,
                          int frameSize,
                          int batchSize,
                          hl_activation_mode_t active_node,
                          hl_activation_mode_t active_gate,
                          hl_activation_mode_t active_state) {
  (void)op;
  (void)value;
  (void)grad;
  (void)frameSize;
  (void)batchSize;
  (void)active_node;
  (void)active_gate;
  (void)active_state;
}

#endif

#endif /* HL_GPU_LSTM_CUH_ */