// gru_gpu_kernel.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * GRU forward pass, reset-output stage.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes the frame (0 .. frame_size-1). The y dimension
 * indexes the batch row and is only consulted when is_batch is true.
 * Threads outside [0, frame_size) x [0, batch_size) exit immediately, so the
 * grid may over-cover both dimensions.
 *
 * Per batch row, gate_value holds 3 * frame_size values. Slice 0 is the
 * update gate and slice 1 is the reset gate; slice 2 is not touched here.
 * reset_output_value and prev_output_value hold frame_size values per row.
 *
 * op_reset_output is called with (update gate, reset gate, previous output,
 * reset-output slot, active_gate). It presumably takes these by reference:
 * the possibly modified gate values are stored back into gate_value, and the
 * slot's final value is stored into reset_output_value. A null
 * prev_output_value is treated as an all-zero previous output.
 */
template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
                                        int batch_size,
                                        ActivationType active_gate) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= frame_size) return;

  // Batch row owned by this thread; always row 0 for the unbatched variant,
  // which makes every row offset below a no-op in that case.
  const int row =
      is_batch ? static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y) : 0;
  if (is_batch && row >= batch_size) return;

  T *gates = gate_value + row * 3 * frame_size;
  T *reset_row = reset_output_value + row * frame_size;

  T update_gate = gates[col];
  T reset_gate = gates[frame_size + col];

  // Missing previous output behaves like zeros.
  T prev_out = 0;
  if (prev_output_value != nullptr) {
    const T *prev_row = prev_output_value + row * frame_size;
    prev_out = prev_row[col];
  }

  T reset_output;
  op_reset_output(update_gate, reset_gate, prev_out, reset_output,
                  active_gate);

  gates[col] = update_gate;
  gates[frame_size + col] = reset_gate;
  reset_row[col] = reset_output;
}

/*
 * GRU forward pass, final-output stage.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes the frame (0 .. frame_size-1). The y dimension
 * indexes the batch row and is only consulted when is_batch is true.
 * Threads outside [0, frame_size) x [0, batch_size) exit immediately.
 *
 * Per batch row, gate_value holds 3 * frame_size values. This kernel reads
 * slice 0 (update gate) and slice 2 (frame state), and writes back slice 2
 * only. output_value and prev_output_value hold frame_size values per row.
 *
 * op_final_output is called with (update gate, frame state, previous output,
 * output slot, active_node). The possibly modified frame state is stored back
 * into gate_value, and the slot's final value is stored into output_value.
 * A null prev_output_value is treated as an all-zero previous output.
 */
template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
                                        int batch_size,
                                        ActivationType active_node) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= frame_size) return;

  // Batch row owned by this thread; row 0 (no offset) when !is_batch.
  const int row =
      is_batch ? static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y) : 0;
  if (is_batch && row >= batch_size) return;

  T *gates = gate_value + row * 3 * frame_size;
  T *out_row = output_value + row * frame_size;

  T update_gate = gates[col];
  T frame_state = gates[2 * frame_size + col];

  // Missing previous output behaves like zeros.
  T prev_out = 0;
  if (prev_output_value != nullptr) {
    const T *prev_row = prev_output_value + row * frame_size;
    prev_out = prev_row[col];
  }

  T output;
  op_final_output(update_gate, frame_state, prev_out, output, active_node);

  gates[2 * frame_size + col] = frame_state;
  out_row[col] = output;
}

/*
 * GRU backward pass, state-gradient stage.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes the frame (0 .. frame_size-1). The y dimension
 * indexes the batch row and is only consulted when is_batch is true.
 * Threads outside [0, frame_size) x [0, batch_size) exit immediately.
 *
 * Per batch row, gate_value and gate_grad hold 3 * frame_size values. This
 * kernel reads gate_value slices 0 (update gate) and 2 (frame state), and
 * writes gate_grad slices 0 and 2. output_grad, prev_out_value and
 * prev_out_grad hold frame_size values per row.
 *
 * prev_out_value and prev_out_grad are optional and handled independently:
 *  - a null prev_out_value is treated as an all-zero previous output, the
 *    same convention as the forward kernels;
 *  - a null prev_out_grad is neither read nor written. Otherwise its current
 *    value is passed to op_state_grad and the result is stored back into the
 *    same row.
 */
template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size, int batch_size,
                                       ActivationType active_node) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    output_grad += batch_idx * frame_size;
    // Offset the optional per-row buffers here, each on its own null check,
    // so that both the loads below and the write-back of prev_out_grad
    // address this thread's row. Previously the offsets only happened when
    // both pointers were non-null, while the write-back did not require
    // that, so rows could collide on row 0.
    if (prev_out_value) prev_out_value += batch_idx * frame_size;
    if (prev_out_grad) prev_out_grad += batch_idx * frame_size;
  }

  T r_update_gate_grad;
  T r_frame_state_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
  T r_out_grad = output_grad[frame_idx];

  // Load each optional operand independently:
  //  - op_state_grad receives the previous output as an input to the gate
  //    gradients it produces, so it must see the real value even when no
  //    gradient is requested for it;
  //  - prev_out_grad is passed in and then stored back, so it is loaded
  //    whenever it will be stored. Otherwise its existing contents would be
  //    replaced by a result computed from 0.
  if (prev_out_value) r_prev_out_value = prev_out_value[frame_idx];
  if (prev_out_grad) r_prev_out_grad = prev_out_grad[frame_idx];

  op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
                r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
                r_out_grad, active_node);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}

/*
 * GRU backward pass, reset-gradient stage.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * The x dimension indexes the frame (0 .. frame_size-1). The y dimension
 * indexes the batch row and is only consulted when is_batch is true.
 * Threads outside [0, frame_size) x [0, batch_size) exit immediately.
 *
 * Per batch row, gate_value and gate_grad hold 3 * frame_size values. This
 * kernel reads gate_value slices 0 (update gate) and 1 (reset gate), reads
 * gate_grad slice 0, and writes gate_grad slices 0 and 1. The slice-0
 * gradient it reads is presumably the one written by
 * KeGruBackwardStateGrad; confirm the launch order at the call site.
 * reset_output_grad, prev_out_value and prev_out_grad hold frame_size values
 * per row.
 *
 * prev_out_value, prev_out_grad and reset_output_grad are optional and
 * handled independently. A null input is treated as zeros. A null
 * prev_out_grad is neither read nor written; otherwise the value computed by
 * op_reset_grad is stored back into the same row.
 */
template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size, int batch_size,
                                       ActivationType active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    // Offset every optional per-row buffer here, each on its own null check,
    // so that the loads below and the write-back of prev_out_grad always
    // address this thread's row, whichever subset of pointers is provided.
    if (reset_output_grad) reset_output_grad += batch_idx * frame_size;
    if (prev_out_value) prev_out_value += batch_idx * frame_size;
    if (prev_out_grad) prev_out_grad += batch_idx * frame_size;
  }

  T r_reset_gate_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_reset_output_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
  T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];

  // Load each optional operand independently. Previously all three were
  // loaded only when prev_out_value and prev_out_grad were both present, so
  // supplying a previous output without its gradient fed op_reset_grad zeros
  // for both the previous output and the reset-output gradient.
  if (prev_out_value) r_prev_out_value = prev_out_value[frame_idx];
  if (prev_out_grad) r_prev_out_grad = prev_out_grad[frame_idx];
  if (reset_output_grad) r_reset_output_grad = reset_output_grad[frame_idx];

  op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
                r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
                r_reset_output_grad, active_gate);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}
}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle