// File: gru_gpu_kernel.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"

#include <glog/logging.h>

namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * KeGruForwardResetOutput: first half of one GRU forward step.
 *
 * For every (batch, frame) element this kernel:
 *   1. Loads the update gate, the reset gate and the previous output.
 *   2. Runs op_reset_output on them.
 *   3. Writes the (possibly transformed) gate values back into gate_value.
 *   4. Stores the produced reset output in reset_output_value.
 *
 * Expected launch configuration:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes frames. The y dimension indexes batch rows and is
 * only consulted when is_batch is true; otherwise a single row is processed.
 *
 * Per-row layout, as implied by the indexing below:
 *   gate_value         : 3 * frame_size values. Slot 0 = update gate,
 *                        slot 1 = reset gate; slot 2 is not touched here.
 *   reset_output_value : frame_size values (output).
 *   prev_output_value  : frame_size values, or nullptr. When nullptr, the
 *                        previous output is taken as 0 (presumably the first
 *                        time step; confirm with the caller).
 *
 * active_gate is forwarded unchanged to op_reset_output.
 */
template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
                                        int batch_size,
                                        activation_mode_t active_gate) {
  // The CUDA built-in index variables are blockIdx / blockDim / threadIdx;
  // snake_case spellings such as block_idx are not CUDA built-ins.
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Move the row pointers to this thread's batch row.
    gate_value += batch_idx * 3 * frame_size;
    reset_output_value += batch_idx * frame_size;
  }

  T r_prev_out = 0;
  T r_value_reset_output;  // filled in by op_reset_output
  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
  T r_value_reset_gate = gate_value[frame_idx + frame_size * 1];

  if (prev_output_value) {
    if (is_batch) prev_output_value += batch_idx * frame_size;
    r_prev_out = prev_output_value[frame_idx];
  }

  op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
                  r_value_reset_output, active_gate);

  // The op may transform the gate values in place; persist them.
  gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
  gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
  reset_output_value[frame_idx] = r_value_reset_output;
}

/*
 * KeGruForwardFinalOutput: second half of one GRU forward step.
 *
 * For every (batch, frame) element this kernel:
 *   1. Loads the update gate, the frame state and the previous output.
 *   2. Runs op_final_output on them.
 *   3. Writes the (possibly transformed) frame state back into gate_value.
 *   4. Stores the produced output in output_value.
 *
 * Expected launch configuration:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes frames. The y dimension indexes batch rows and is
 * only consulted when is_batch is true; otherwise a single row is processed.
 *
 * Per-row layout, as implied by the indexing below:
 *   gate_value        : 3 * frame_size values. Slot 0 = update gate (read
 *                       only), slot 2 = frame state (read and written back).
 *   output_value      : frame_size values (output).
 *   prev_output_value : frame_size values, or nullptr. When nullptr, the
 *                       previous output is taken as 0.
 *
 * active_node is forwarded unchanged to op_final_output.
 */
template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
                                        int batch_size,
                                        activation_mode_t active_node) {
  // The CUDA built-in index variables are blockIdx / blockDim / threadIdx;
  // snake_case spellings such as block_idx are not CUDA built-ins.
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Move the row pointers to this thread's batch row.
    gate_value += batch_idx * 3 * frame_size;
    output_value += batch_idx * frame_size;
  }

  T r_output;  // filled in by op_final_output
  T r_prev_out = 0;
  T r_value_update_gate = gate_value[frame_idx + frame_size * 0];
  T r_value_frame_state = gate_value[frame_idx + frame_size * 2];

  if (prev_output_value) {
    if (is_batch) prev_output_value += batch_idx * frame_size;
    r_prev_out = prev_output_value[frame_idx];
  }

  op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
                  r_output, active_node);

  // The op may transform the frame state in place; persist it.
  gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
  output_value[frame_idx] = r_output;
}

/*
 * KeGruBackwardStateGrad: first half of one GRU backward step.
 *
 * For every (batch, frame) element this kernel:
 *   1. Loads the update gate value, the frame state value and the output
 *      gradient.
 *   2. When present, loads the previous output and its gradient.
 *   3. Runs op_state_grad, which is expected to fill in the update gate
 *      gradient and the frame state gradient, and may update the
 *      previous-output gradient.
 * Results go to gate_grad slots 0 and 2 and, when prev_out_grad is non-null,
 * back into prev_out_grad.
 *
 * Expected launch configuration:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes frames. The y dimension indexes batch rows and is
 * only consulted when is_batch is true; otherwise a single row is processed.
 *
 * Per-row layout, as implied by the indexing below:
 *   gate_value, gate_grad : 3 * frame_size values. Slot 0 = update gate,
 *                           slot 2 = frame state.
 *   output_grad           : frame_size values (input).
 *   prev_out_value        : frame_size values, or nullptr (taken as 0).
 *   prev_out_grad         : frame_size values, or nullptr (no write-back).
 *
 * prev_out_value and prev_out_grad are handled independently. The previous
 * output is loaded whenever it exists, matching the forward kernels. The
 * previous-output gradient is offset, loaded and stored under a single null
 * check, so the store always targets this thread's own batch row.
 *
 * active_node is forwarded unchanged to op_state_grad.
 */
template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size, int batch_size,
                                       activation_mode_t active_node) {
  // The CUDA built-in index variables are blockIdx / blockDim / threadIdx;
  // snake_case spellings such as block_idx are not CUDA built-ins.
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Move the row pointers to this thread's batch row.
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    output_grad += batch_idx * frame_size;
  }

  T r_update_gate_grad;  // filled in by op_state_grad
  T r_frame_state_grad;  // filled in by op_state_grad
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
  T r_out_grad = output_grad[frame_idx];

  // Load the previous output whenever it exists, even if no gradient buffer
  // was supplied for it. The gate gradients computed by op_state_grad may
  // depend on this value.
  if (prev_out_value) {
    if (is_batch) prev_out_value += batch_idx * frame_size;
    r_prev_out_value = prev_out_value[frame_idx];
  }

  // Offset, load and (below) store prev_out_grad under the same null check.
  // This keeps the store in this thread's batch row and lets the op see the
  // gradient already present there.
  if (prev_out_grad) {
    if (is_batch) prev_out_grad += batch_idx * frame_size;
    r_prev_out_grad = prev_out_grad[frame_idx];
  }

  op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                r_out_grad, active_node);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}

/*
 * KeGruBackwardResetGrad: second half of one GRU backward step.
 *
 * For every (batch, frame) element this kernel:
 *   1. Loads the update gate value and its current gradient, and the reset
 *      gate value.
 *   2. When present, loads the previous output, its gradient and the
 *      reset-output gradient.
 *   3. Runs op_reset_grad, which is expected to fill in the reset gate
 *      gradient and may transform the update gate gradient and the
 *      previous-output gradient.
 * Results go to gate_grad slots 0 and 1 and, when prev_out_grad is non-null,
 * back into prev_out_grad.
 *
 * Expected launch configuration:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes frames. The y dimension indexes batch rows and is
 * only consulted when is_batch is true; otherwise a single row is processed.
 *
 * Per-row layout, as implied by the indexing below:
 *   gate_value, gate_grad : 3 * frame_size values. Slot 0 = update gate,
 *                           slot 1 = reset gate.
 *   reset_output_grad     : frame_size values; read only when a previous
 *                           output exists (otherwise taken as 0).
 *   prev_out_value        : frame_size values, or nullptr (taken as 0).
 *   prev_out_grad         : frame_size values, or nullptr (no write-back).
 *
 * prev_out_value and prev_out_grad are handled independently, so the
 * prev_out_grad store always targets this thread's own batch row.
 *
 * active_gate is forwarded unchanged to op_reset_grad.
 */
template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size, int batch_size,
                                       activation_mode_t active_gate) {
  // The CUDA built-in index variables are blockIdx / blockDim / threadIdx;
  // snake_case spellings such as block_idx are not CUDA built-ins.
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;

  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    // Move the gate pointers to this thread's batch row. The optional
    // per-row pointers are offset only where they are dereferenced.
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
  }

  T r_reset_gate_grad;  // filled in by op_reset_grad
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_reset_output_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
  T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];

  // Load the previous output whenever it exists, even if no gradient buffer
  // was supplied for it. The reset gate gradient may depend on this value.
  if (prev_out_value) {
    if (is_batch) prev_out_value += batch_idx * frame_size;
    r_prev_out_value = prev_out_value[frame_idx];
    // The reset-output gradient is only read when a previous output exists.
    // Presumably the caller only computes it in that case (TODO confirm).
    if (reset_output_grad) {
      if (is_batch) reset_output_grad += batch_idx * frame_size;
      r_reset_output_grad = reset_output_grad[frame_idx];
    }
  }

  // Offset, load and (below) store prev_out_grad under the same null check.
  // This keeps the store in this thread's batch row and lets the op see the
  // gradient already present there.
  if (prev_out_grad) {
    if (is_batch) prev_out_grad += batch_idx * frame_size;
    r_prev_out_grad = prev_out_grad[frame_idx];
  }

  op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                r_reset_output_grad, active_gate);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}
}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle