/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <cstdint>
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * Forward GRU step, part 1: gate activation and reset output.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * Each thread handles one (frame_idx, batch_idx) element. When is_batch is
 * false only the x dimension is used and the pointers are not offset, so
 * they are expected to address a single row.
 *
 * Per-row layout, as indexed below:
 *   gate_value          3 * frame_size values. [0, frame_size) is read as the
 *                       update gate and [frame_size, 2 * frame_size) as the
 *                       reset gate. Both are written back after the op.
 *   reset_output_value  frame_size values, written.
 *   prev_output_value   frame_size values, read. May be nullptr, in which case
 *                       the op receives 0 as the previous output.
 *
 * The math lives in op_reset_output. It receives both gates and the previous
 * output by pointer and is expected to set the reset output (see the functor
 * definition for the exact formula).
 */
template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
                                        int batch_size,
                                        ActivationType active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  if (is_batch) {
    const int batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Compute row offsets in 64 bits: batch_idx * 3 * frame_size can exceed
    // INT_MAX for large batches even though each factor fits in an int.
    const int64_t row_offset = static_cast<int64_t>(batch_idx) * frame_size;
    gate_value += 3 * row_offset;
    reset_output_value += row_offset;
    if (prev_output_value) prev_output_value += row_offset;
  }

  T r_prev_out = 0;
  if (prev_output_value) r_prev_out = prev_output_value[frame_idx];
  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
  T r_value_reset_gate = gate_value[frame_idx + frame_size * 1];
  T r_value_reset_output;

  op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
                  &r_value_reset_output, active_gate);

  gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
  gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
  reset_output_value[frame_idx] = r_value_reset_output;
}

/*
 * Forward GRU step, part 2: frame-state activation and final output.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * Each thread handles one (frame_idx, batch_idx) element. When is_batch is
 * false only the x dimension is used and the pointers are not offset, so
 * they are expected to address a single row.
 *
 * Per-row layout, as indexed below:
 *   gate_value         3 * frame_size values. [0, frame_size) is read as the
 *                      update gate. [2 * frame_size, 3 * frame_size) is read
 *                      as the frame state and written back after the op.
 *   output_value       frame_size values, written.
 *   prev_output_value  frame_size values, read. May be nullptr, in which case
 *                      the op receives 0 as the previous output.
 *
 * The math lives in op_final_output. It is expected to set the output and
 * may modify the frame state in place. origin_mode is forwarded to it
 * unchanged; its meaning is defined by the functor.
 */
template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
                                        int batch_size,
                                        ActivationType active_node,
                                        bool origin_mode) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  if (is_batch) {
    const int batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Compute row offsets in 64 bits: batch_idx * 3 * frame_size can exceed
    // INT_MAX for large batches even though each factor fits in an int.
    const int64_t row_offset = static_cast<int64_t>(batch_idx) * frame_size;
    gate_value += 3 * row_offset;
    output_value += row_offset;
    if (prev_output_value) prev_output_value += row_offset;
  }

  T r_prev_out = 0;
  if (prev_output_value) r_prev_out = prev_output_value[frame_idx];
  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
  T r_value_frame_state = gate_value[frame_idx + frame_size * 2];
  T r_output;

  op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
                  &r_output, active_node, origin_mode);

  gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
  output_value[frame_idx] = r_output;
}

/*
 * Backward GRU step, part 1: gradients for the update gate, the frame state
 * and (optionally) the previous output.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * Each thread handles one (frame_idx, batch_idx) element. When is_batch is
 * false only the x dimension is used and the pointers are not offset, so
 * they are expected to address a single row.
 *
 * Per-row layout, as indexed below:
 *   gate_value    3 * frame_size values. Slot 0 (update gate) and slot 2
 *                 (frame state) are read.
 *   gate_grad     3 * frame_size values. Slots 0 and 2 are written with the
 *                 gradients produced by the op.
 *   output_grad   frame_size values, read.
 *   prev_out_value, prev_out_grad
 *                 frame_size values each. They are used only when BOTH are
 *                 non-null. Otherwise the op receives 0 for both and nothing
 *                 is written to prev_out_grad.
 *
 * The math lives in op_state_grad, which is expected to set the update-gate
 * and frame-state gradients and may update the previous-output gradient.
 * origin_mode is forwarded to it unchanged.
 */
template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size, int batch_size,
                                       ActivationType active_node,
                                       bool origin_mode) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  // 64-bit row offset: batch_idx * 3 * frame_size can exceed INT_MAX.
  int64_t row_offset = 0;
  if (is_batch) {
    const int batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    row_offset = static_cast<int64_t>(batch_idx) * frame_size;
    gate_value += 3 * row_offset;
    gate_grad += 3 * row_offset;
    output_grad += row_offset;
  }

  // A single predicate governs offsetting, reading AND writing the
  // previous-output buffers. A weaker write guard would let a thread store
  // into prev_out_grad without its batch row offset, so every batch row
  // would race on row 0.
  const bool has_prev = prev_out_value && prev_out_grad;

  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  if (has_prev) {
    prev_out_value += row_offset;
    prev_out_grad += row_offset;
    r_prev_out_value = prev_out_value[frame_idx];
    r_prev_out_grad = prev_out_grad[frame_idx];
  }

  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
  T r_out_grad = output_grad[frame_idx];
  T r_update_gate_grad;
  T r_frame_state_grad;

  op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
                &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
                &r_out_grad, active_node, origin_mode);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
  if (has_prev) prev_out_grad[frame_idx] = r_prev_out_grad;
}

/*
 * Backward GRU step, part 2: gradients for the reset gate, the update gate
 * and (optionally) the previous output.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * Each thread handles one (frame_idx, batch_idx) element. When is_batch is
 * false only the x dimension is used and the pointers are not offset, so
 * they are expected to address a single row.
 *
 * Per-row layout, as indexed below:
 *   gate_value         3 * frame_size values. Slot 0 (update gate) and
 *                      slot 1 (reset gate) are read.
 *   gate_grad          3 * frame_size values. Slot 0 is read as the incoming
 *                      update-gate gradient and overwritten with the op's
 *                      result. Slot 1 is written.
 *   reset_output_grad  frame_size values. Read only when prev_out_value and
 *                      prev_out_grad are both non-null; otherwise the op
 *                      receives 0.
 *   prev_out_value, prev_out_grad
 *                      frame_size values each. They are used only when BOTH
 *                      are non-null. Otherwise the op receives 0 for both and
 *                      nothing is written to prev_out_grad.
 *
 * The math lives in op_reset_grad, which is expected to set the reset-gate
 * gradient and may update the update-gate and previous-output gradients.
 * origin_mode is forwarded to it unchanged.
 */
template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size, int batch_size,
                                       ActivationType active_gate,
                                       bool origin_mode) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  // 64-bit row offset: batch_idx * 3 * frame_size can exceed INT_MAX.
  int64_t row_offset = 0;
  if (is_batch) {
    const int batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    row_offset = static_cast<int64_t>(batch_idx) * frame_size;
    gate_value += 3 * row_offset;
    gate_grad += 3 * row_offset;
  }

  // A single predicate governs offsetting, reading AND writing the
  // previous-output buffers. A weaker write guard would let a thread store
  // into prev_out_grad without its batch row offset, so every batch row
  // would race on row 0.
  const bool has_prev = prev_out_value && prev_out_grad;

  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_reset_output_grad = 0;
  if (has_prev) {
    prev_out_value += row_offset;
    prev_out_grad += row_offset;
    reset_output_grad += row_offset;
    r_prev_out_value = prev_out_value[frame_idx];
    r_prev_out_grad = prev_out_grad[frame_idx];
    r_reset_output_grad = reset_output_grad[frame_idx];
  }

  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
  T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];
  T r_reset_gate_grad;

  op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
                &r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
                &r_reset_output_grad, active_gate, origin_mode);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
  if (has_prev) prev_out_grad[frame_idx] = r_prev_out_grad;
}
}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle