lstm_gpu_kernel.h 9.0 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16
#include "paddle/operators/math/detail/activation_functions.h"
D
dangqingqing 已提交
17 18
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
19
#include "paddle/platform/device_context.h"
D
dangqingqing 已提交
20

21
#include <type_traits>
22

D
dangqingqing 已提交
23 24 25 26 27 28
namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * Forward LSTM kernel; one thread per (frame, batch row) element.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes frames in [0, frame_size). With is_batch == false
 * only x is used and every buffer in `value` is a single row; with
 * is_batch == true the y dimension selects the batch row in [0, batch_size).
 *
 * Per row, gate_value holds 4 * frame_size entries arranged as four
 * frame_size-wide slots (in | ig | fg | og); state_value, state_active_value,
 * output_value and prev_state_value hold frame_size entries each. The
 * peephole vectors check_ig / check_fg / check_og are indexed by frame only
 * and read as 0 when null; prev_state_value also reads as 0 when null.
 * Whatever `op` leaves in the four gate values, the state, the activated
 * state and the output is stored back to global memory.
 */
template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
                              int batch_size, ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= frame_size) return;

  int row = 0;
  if (is_batch) {
    row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= batch_size) return;
  }

  // Element offset of this thread's row. It is 0 in the non-batch
  // instantiation, which leaves every pointer below unchanged.
  const int row_offset = row * frame_size;
  T* gates = value.gate_value + row_offset * 4;
  T* state_out = value.state_value + row_offset;
  T* state_act_out = value.state_active_value + row_offset;
  T* output_out = value.output_value + row_offset;

  T in_val = gates[col];
  T ig_val = gates[col + frame_size];
  T fg_val = gates[col + 2 * frame_size];
  T og_val = gates[col + 3 * frame_size];

  // Kept as mutable locals: `op` may take its arguments by non-const
  // reference.
  T peep_i = value.check_ig ? value.check_ig[col] : 0;
  T peep_f = value.check_fg ? value.check_fg[col] : 0;
  T peep_o = value.check_og ? value.check_og[col] : 0;

  T prev_c = 0;
  if (value.prev_state_value) {
    prev_c = value.prev_state_value[row_offset + col];
  }

  // Outputs filled in by `op`.
  T c_val;
  T c_act;
  T h_val;
  op(in_val, ig_val, fg_val, og_val, prev_c, c_val, c_act, h_val, peep_i,
     peep_f, peep_o, active_node, active_gate, active_state);

  gates[col] = in_val;
  gates[col + frame_size] = ig_val;
  gates[col + 2 * frame_size] = fg_val;
  gates[col + 3 * frame_size] = og_val;

  state_out[col] = c_val;
  state_act_out[col] = c_act;
  output_out[col] = h_val;
}

/*
 * Backward LSTM kernel; one thread per (frame, batch row) element.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes frames in [0, frame_size); with is_batch == true
 * the y dimension selects the batch row in [0, batch_size).
 *
 * Reads the forward values (gates, state, activated state, optional previous
 * state and peephole weights) from `value`, and the output / state gradients
 * from `grad`. It then hands everything to `op` and stores:
 *   - the four gate gradients (same 4 * frame_size per-row layout as
 *     gate_value),
 *   - the state gradient as left by `op`,
 *   - the previous-state gradient, when grad.prev_state_grad is non-null.
 * Peephole-weight gradients are indexed by frame only, so in batch mode
 * several rows update the same element and the update is done atomically.
 * The ig / fg peephole gradients are only accumulated when a previous state
 * exists.
 */
template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
                               LstmMetaGrad<T> grad, int frame_size,
                               int batch_size, ActivationType active_node,
                               ActivationType active_gate,
                               ActivationType active_state) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= frame_size) return;

  int row = 0;
  if (is_batch) {
    row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= batch_size) return;
  }

  // Element offset of this thread's row (0 in the non-batch instantiation).
  const int row_offset = row * frame_size;
  T* gates = value.gate_value + row_offset * 4;
  T* gate_grads = grad.gate_grad + row_offset * 4;
  T* state_grads = grad.state_grad + row_offset;

  // Forward values and incoming gradients. All are mutable locals because
  // `op` may take them by non-const reference.
  T in_val = gates[col];
  T ig_val = gates[col + frame_size];
  T fg_val = gates[col + 2 * frame_size];
  T og_val = gates[col + 3 * frame_size];
  T c_val = value.state_value[row_offset + col];
  T c_act = value.state_active_value[row_offset + col];
  T h_grad = grad.output_grad[row_offset + col];
  T c_grad = state_grads[col];

  T peep_i = value.check_ig ? value.check_ig[col] : 0;
  T peep_f = value.check_fg ? value.check_fg[col] : 0;
  T peep_o = value.check_og ? value.check_og[col] : 0;

  T prev_c = 0;
  if (value.prev_state_value) {
    prev_c = value.prev_state_value[row_offset + col];
  }

  // Outputs filled in by `op`.
  T in_grad;
  T ig_grad;
  T fg_grad;
  T og_grad;
  T prev_c_grad;
  T peep_i_grad;
  T peep_f_grad;
  T peep_o_grad;

  op(in_val, ig_val, fg_val, og_val, in_grad, ig_grad, fg_grad, og_grad,
     prev_c, prev_c_grad, c_val, c_grad, c_act, h_grad, peep_i, peep_f,
     peep_o, peep_i_grad, peep_f_grad, peep_o_grad, active_node, active_gate,
     active_state);

  gate_grads[col] = in_grad;
  gate_grads[col + frame_size] = ig_grad;
  gate_grads[col + 2 * frame_size] = fg_grad;
  gate_grads[col + 3 * frame_size] = og_grad;
  state_grads[col] = c_grad;
  if (grad.prev_state_grad) {
    grad.prev_state_grad[row_offset + col] = prev_c_grad;
  }

  const bool has_prev = value.prev_state_value != nullptr;
  if (is_batch) {
    // Peephole gradients are shared across rows: accumulate atomically.
    if (has_prev && grad.check_ig_grad)
      paddle::platform::CudaAtomicAdd(grad.check_ig_grad + col, peep_i_grad);
    if (has_prev && grad.check_fg_grad)
      paddle::platform::CudaAtomicAdd(grad.check_fg_grad + col, peep_f_grad);
    if (grad.check_og_grad)
      paddle::platform::CudaAtomicAdd(grad.check_og_grad + col, peep_o_grad);
  } else {
    // Single row: each frame element is owned by exactly one thread.
    if (has_prev && grad.check_ig_grad) grad.check_ig_grad[col] += peep_i_grad;
    if (has_prev && grad.check_fg_grad) grad.check_fg_grad[col] += peep_f_grad;
    if (grad.check_og_grad) grad.check_og_grad[col] += peep_o_grad;
  }
}

/*
 * Launches KeLstmForward on the CUDA stream owned by `context`.
 *
 * `context` must really be a platform::CUDADeviceContext; it is downcast to
 * obtain its stream. Launch configuration:
 *   - batch_size == 1: 1-D launch over frames, at most 1024 threads per
 *     block, using the non-batch kernel instantiation;
 *   - otherwise: 32 x 32 thread blocks tiling (frame, batch rows).
 * Empty inputs (frame_size <= 0 or batch_size <= 0) launch nothing, because
 * a zero-sized grid or block is an invalid launch configuration.
 * The launch is asynchronous; launch errors are not checked here.
 */
template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
                      LstmMetaValue<T> value, int frame_size, int batch_size,
                      ActivationType active_node, ActivationType active_gate,
                      ActivationType active_state) {
  if (frame_size <= 0 || batch_size <= 0) return;

  // static_cast is the correct base-to-derived reference conversion;
  // reinterpret_cast would skip any base-subobject adjustment.
  auto stream =
      static_cast<const platform::CUDADeviceContext&>(context).stream();

  if (batch_size == 1) {
    constexpr int kMaxFramesPerBlock = 1024;
    const int frame_per_block =
        frame_size <= kMaxFramesPerBlock ? frame_size : kMaxFramesPerBlock;
    const int frame_blocks =
        (frame_size + kMaxFramesPerBlock - 1) / kMaxFramesPerBlock;
    dim3 threads(frame_per_block, 1);
    dim3 grid(frame_blocks, 1);
    KeLstmForward<T, Op,
                  /* is_batch= */ false><<<grid, threads, 0, stream>>>(
        op, value, frame_size, batch_size, active_node, active_gate,
        active_state);
  } else {
    constexpr int kFramesPerBlock = 32;
    constexpr int kBatchesPerBlock = 32;
    dim3 threads(kFramesPerBlock, kBatchesPerBlock);
    dim3 grid((frame_size + kFramesPerBlock - 1) / kFramesPerBlock,
              (batch_size + kBatchesPerBlock - 1) / kBatchesPerBlock);
    KeLstmForward<T, Op,
                  /* is_batch= */ true><<<grid, threads, 0, stream>>>(
        op, value, frame_size, batch_size, active_node, active_gate,
        active_state);
  }
}

/*
 * Launches KeLstmBackward on the CUDA stream owned by `context`.
 *
 * `context` must really be a platform::CUDADeviceContext; it is downcast to
 * obtain its stream. Launch configuration:
 *   - batch_size == 1: 1-D launch over frames, at most 1024 threads per
 *     block, using the non-batch kernel instantiation;
 *   - otherwise: 32 x 16 thread blocks tiling (frame, batch rows).
 * Empty inputs (frame_size <= 0 or batch_size <= 0) launch nothing, because
 * a zero-sized grid or block is an invalid launch configuration.
 * The launch is asynchronous; launch errors are not checked here.
 */
template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
                       int frame_size, int batch_size,
                       ActivationType active_node, ActivationType active_gate,
                       ActivationType active_state) {
  if (frame_size <= 0 || batch_size <= 0) return;

  // static_cast is the correct base-to-derived reference conversion;
  // reinterpret_cast would skip any base-subobject adjustment.
  auto stream =
      static_cast<const platform::CUDADeviceContext&>(context).stream();

  if (batch_size == 1) {
    constexpr int kMaxFramesPerBlock = 1024;
    const int frame_per_block =
        frame_size <= kMaxFramesPerBlock ? frame_size : kMaxFramesPerBlock;
    const int frame_blocks =
        (frame_size + kMaxFramesPerBlock - 1) / kMaxFramesPerBlock;
    dim3 threads(frame_per_block, 1);
    dim3 grid(frame_blocks, 1);
    KeLstmBackward<T, Op,
                   /* is_batch= */ false><<<grid, threads, 0, stream>>>(
        op, value, grad, frame_size, batch_size, active_node, active_gate,
        active_state);
  } else {
    constexpr int kFramesPerBlock = 32;
    constexpr int kBatchesPerBlock = 16;
    dim3 threads(kFramesPerBlock, kBatchesPerBlock);
    dim3 grid((frame_size + kFramesPerBlock - 1) / kFramesPerBlock,
              (batch_size + kBatchesPerBlock - 1) / kBatchesPerBlock);
    KeLstmBackward<T, Op,
                   /* is_batch= */ true><<<grid, threads, 0, stream>>>(
        op, value, grad, frame_size, batch_size, active_node, active_gate,
        active_state);
  }
}

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle