/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {
namespace math {
namespace detail {

/*
 * GRU forward, part 1 (gates -> reset output).
 *
 * Each thread owns one (batch row, frame column) element. It loads:
 *   - the update gate (slot 0) from gate_value;
 *   - the reset gate (slot 1) from gate_value;
 *   - the previous output, when prev_output_value is non-null (0 otherwise).
 * It hands these to op_reset_output, writes both gate values back to
 * gate_value, and stores the functor's result in reset_output_value.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * With is_batch == false only the x dimension is used and row 0 is processed.
 *
 * Row strides:
 *   - gate_value: 3 * frame_size;
 *   - reset_output_value and prev_output_value: frame_size.
 */
template <class OpResetOutput, bool is_batch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
                                        int batch_size,
                                        ActivationType active_gate) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row =
      is_batch ? static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y) : 0;
  // Threads past the tensor edge (in either dimension) have nothing to do.
  if (col >= frame_size || (is_batch && row >= batch_size)) return;

  T *gate_row = gate_value + row * 3 * frame_size;
  const int out_pos = row * frame_size + col;

  T update_gate = gate_row[col];
  T reset_gate = gate_row[frame_size + col];
  T prev_out = 0;
  if (prev_output_value != nullptr) {
    prev_out = prev_output_value[out_pos];
  }

  T reset_out;
  op_reset_output(&update_gate, &reset_gate, &prev_out, &reset_out,
                  active_gate);

  gate_row[col] = update_gate;
  gate_row[frame_size + col] = reset_gate;
  reset_output_value[out_pos] = reset_out;
}

/*
 * GRU forward, part 2 (frame state -> output).
 *
 * Each thread owns one (batch row, frame column) element. It loads:
 *   - the update gate (slot 0) from gate_value;
 *   - the frame state (slot 2) from gate_value;
 *   - the previous output, when prev_output_value is non-null (0 otherwise).
 * It passes these to op_final_output together with active_node and
 * origin_mode. Outputs:
 *   - the frame state is written back to slot 2 of gate_value;
 *   - the computed output is stored in output_value;
 *   - the update-gate slot is only read.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 * With is_batch == false only the x dimension is used and row 0 is processed.
 *
 * Row strides:
 *   - gate_value: 3 * frame_size;
 *   - output_value and prev_output_value: frame_size.
 */
template <class OpFinalOutput, bool is_batch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
                                        int batch_size,
                                        ActivationType active_node,
                                        bool origin_mode) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int row =
      is_batch ? static_cast<int>(blockIdx.y * blockDim.y + threadIdx.y) : 0;
  // Threads past the tensor edge (in either dimension) have nothing to do.
  if (col >= frame_size || (is_batch && row >= batch_size)) return;

  T *gate_row = gate_value + row * 3 * frame_size;
  const int out_pos = row * frame_size + col;

  T update_gate = gate_row[col];
  T frame_state = gate_row[2 * frame_size + col];
  T prev_out = 0;
  if (prev_output_value != nullptr) {
    prev_out = prev_output_value[out_pos];
  }

  T out;
  op_final_output(&update_gate, &frame_state, &prev_out, &out, active_node,
                  origin_mode);

  gate_row[2 * frame_size + col] = frame_state;
  output_value[out_pos] = out;
}

/*
 * Fused GRU gate step for a single sample (no batch dimension).
 *
 * For every column COL in [0, 2 * frame_size) it computes
 *   c = activation(gate_value[COL] +
 *                  sum_r prev_output_value[r] *
 *                        gate_weight[r * 2 * frame_size + COL],
 *                  active_node)
 * and stores c back into gate_value[COL].
 *
 * For the second half, COL in [frame_size, 2 * frame_size), it also writes
 *   reset_output[COL - frame_size] = c * prev_output_value[COL - frame_size].
 *
 * When prev_output_value is null, the matrix product is skipped and the
 * previous output is treated as 0.
 *
 * Launch layout:
 *   threads(Tiled_size, 1)
 *   grid(frame_blocks, 1)
 * Each thread handles one column, so the grid must supply at least
 * 2 * frame_size threads in x. Threads with COL >= 2 * frame_size produce no
 * output but still take part in the shuffles.
 *
 * Preconditions:
 *   - blockDim.x == Tiled_size. Lane threadIdx.x loads element threadIdx.x of
 *     each prev_output_value tile, and __shfl broadcasts it to the other lanes,
 *     so every lane of the block must reach every shuffle.
 *   - Tiled_size is a power of two in [1, 32] (a shuffle-width requirement).
 *   - gate_weight is read row-major, with frame_size rows of 2 * frame_size
 *     values each.
 */
template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value,
                                        T *gate_weight, T *reset_output,
                                        int frame_size,
                                        ActivationType active_node) {
  static_assert(Tiled_size > 0 && Tiled_size <= 32 &&
                    (Tiled_size & (Tiled_size - 1)) == 0,
                "Tiled_size must be a power of two in [1, 32]");

  T xt_0 = 0.0f;
  T a0 = 0.0f;
  T c0 = 0.0f;
  T b0[Tiled_size];

  int COL = blockIdx.x * blockDim.x + threadIdx.x;
  // Mask of lanes 0 .. Tiled_size-1. Shifting all-ones right keeps the shift
  // count in [0, 31]. `(1 << Tiled_size) - 1` would be undefined behaviour
  // for Tiled_size == 32.
  const unsigned Tiled_mask = 0xffffffffu >> (32 - Tiled_size);
  const int num_tiles = (frame_size - 1) / Tiled_size + 1;

  // Tiled matrix-vector product. The prev_output_value tile is exchanged
  // through warp shuffles (registers), not through shared memory.
  if (prev_output_value) {
    for (int k = 0; k < num_tiles; ++k) {
      const int tile_base = k * Tiled_size;
      a0 = 0;
      if ((threadIdx.x + tile_base) < frame_size) {
        a0 = prev_output_value[threadIdx.x + tile_base];
      }
      for (int i = 0; i < Tiled_size; i++) {
        // Out-of-range taps must be 0 rather than left unset. On the first
        // tile an unset b0[i] is an uninitialized read, and 0 * NaN/Inf != 0.
        b0[i] = 0;
        if (COL < frame_size * 2 && (i + tile_base) < frame_size) {
          b0[i] = gate_weight[(i + tile_base) * frame_size * 2 + COL];
        }
      }

      for (int i = 0; i < Tiled_size; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        c0 = c0 + __shfl_sync(Tiled_mask, a0, i, Tiled_size) * b0[i];
#else
        c0 = c0 + __shfl(a0, i, Tiled_size) * b0[i];
#endif
      }
    }
  }

  __syncthreads();

  if (COL < frame_size * 2) {
    xt_0 = gate_value[COL];
    c0 += xt_0;
    c0 = forward::activation(c0, active_node);
    gate_value[COL] = c0;
    if (COL >= frame_size) {
      // Second-half column: gate value times the previous output (0 if none).
      T htp_0 = 0.0;
      if (prev_output_value) {
        htp_0 = prev_output_value[COL - frame_size];
      }
      reset_output[COL - frame_size] = c0 * htp_0;
    }
  }
}

/*
 * Fused GRU output step for a single sample (no batch dimension).
 *
 * For every column COL in [0, frame_size) it computes the frame state
 *   c = activation(gate_value[2 * frame_size + COL] +
 *                  sum_r reset_value[r] * gate_weight[r * frame_size + COL],
 *                  act_node)
 * and stores c into gate_value[2 * frame_size + COL].
 *
 * With u = gate_value[COL] and h = prev_out_value[COL] (0 when
 * prev_out_value is null), it then writes
 *   origin_mode:  output_value[COL] = h * u + (1 - u) * c
 *   otherwise:    output_value[COL] = c * u + (1 - u) * h
 *
 * The matrix product, and hence every read of reset_value, is skipped when
 * prev_out_value is null.
 *
 * Launch layout:
 *   threads(Tiled_size, 1)
 *   grid(frame_blocks, 1)
 * Each thread handles one column, so the grid must supply at least
 * frame_size threads in x.
 *
 * Preconditions:
 *   - blockDim.x == Tiled_size. Every lane of the block must reach every
 *     shuffle.
 *   - Tiled_size is a power of two in [1, 32] (a shuffle-width requirement).
 *   - gate_weight is read row-major as frame_size x frame_size.
 */
template <class T, int Tiled_size>
__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value,
                                       T *output_value, T *gate_value,
                                       T *reset_value, int frame_size,
                                       ActivationType act_node,
                                       bool origin_mode) {
  static_assert(Tiled_size > 0 && Tiled_size <= 32 &&
                    (Tiled_size & (Tiled_size - 1)) == 0,
                "Tiled_size must be a power of two in [1, 32]");

  int COL = blockIdx.x * blockDim.x + threadIdx.x;

  T a0 = 0.0f;
  T b0[Tiled_size];
  T c0 = 0.0f;

  // Mask of lanes 0 .. Tiled_size-1. Shifting all-ones right keeps the shift
  // count in [0, 31]. `(1 << Tiled_size) - 1` would be undefined behaviour
  // for Tiled_size == 32.
  const unsigned Tiled_mask = 0xffffffffu >> (32 - Tiled_size);
  const int num_tiles = (frame_size - 1) / Tiled_size + 1;

  // Tiled matrix-vector product, with the reset_value tile exchanged through
  // warp shuffles.
  if (prev_out_value) {
    for (int k = 0; k < num_tiles; ++k) {
      const int tile_base = k * Tiled_size;
      a0 = 0;
      if ((threadIdx.x + tile_base) < frame_size) {
        a0 = reset_value[threadIdx.x + tile_base];
      }
      for (int i = 0; i < Tiled_size; i++) {
        // Out-of-range taps must be 0 rather than left unset. On the first
        // tile an unset b0[i] is an uninitialized read, and 0 * NaN/Inf != 0.
        b0[i] = 0;
        if (COL < frame_size && (i + tile_base) < frame_size) {
          b0[i] = gate_weight[(i + tile_base) * frame_size + COL];
        }
      }

      for (int i = 0; i < Tiled_size; ++i) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        c0 = c0 + __shfl_sync(Tiled_mask, a0, i, Tiled_size) * b0[i];
#else
        c0 = c0 + __shfl(a0, i, Tiled_size) * b0[i];
#endif
      }
    }
  }

  __syncthreads();

  if (COL < frame_size) {
    T xt_0 = gate_value[COL + 2 * frame_size];
    T gta_0 = gate_value[COL];
    T htp_0 = 0;
    if (prev_out_value) htp_0 = prev_out_value[COL];
    c0 += xt_0;
    c0 = forward::activation(c0, act_node);
    gate_value[COL + 2 * frame_size] = c0;
    if (origin_mode) {
      output_value[COL] = htp_0 * gta_0 + (1 - gta_0) * c0;
    } else {
      output_value[COL] = c0 * gta_0 + (1 - gta_0) * htp_0;
    }
  }
}

/*
 * GRU backward, part 1: gradients for the update gate and the frame state.
 *
 * Each thread owns one (batch row, frame column) element. It reads:
 *   - the update gate (slot 0) and frame state (slot 2) of gate_value;
 *   - output_grad;
 *   - the previous output and its gradient, each independently when its
 *     pointer is non-null (0 otherwise).
 *
 * op_state_grad then fills the slot-0 and slot-2 entries of gate_grad. It
 * receives the previous-output gradient by pointer; the value it leaves there
 * is stored back when prev_out_grad is non-null.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 *
 * Row strides:
 *   - gate_value and gate_grad: 3 * frame_size;
 *   - output_grad, prev_out_value and prev_out_grad: frame_size.
 */
template <class OpStateGrad, bool is_batch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size, int batch_size,
                                       ActivationType active_node,
                                       bool origin_mode) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    output_grad += batch_idx * frame_size;
    // Offset each optional pointer on its own, so every later access,
    // including the final write-back, targets this thread's batch row.
    if (prev_out_value) prev_out_value += batch_idx * frame_size;
    if (prev_out_grad) prev_out_grad += batch_idx * frame_size;
  }

  T r_update_gate_grad;
  T r_frame_state_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_frame_state_value = gate_value[frame_idx + frame_size * 2];
  T r_out_grad = output_grad[frame_idx];

  // Read the two optional inputs independently, as the forward kernels do
  // with prev_output_value. This way op_state_grad sees the real previous
  // output even when no gradient for it is requested (prev_out_grad null).
  if (prev_out_value) r_prev_out_value = prev_out_value[frame_idx];
  if (prev_out_grad) r_prev_out_grad = prev_out_grad[frame_idx];

  op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
                &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
                &r_out_grad, active_node, origin_mode);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
  if (prev_out_grad) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}

/*
 * GRU backward, part 2: gradients for the update and reset gates.
 *
 * Each thread owns one (batch row, frame column) element. It always reads:
 *   - the update-gate value and gradient (slot 0);
 *   - the reset-gate value (slot 1).
 * Only when BOTH prev_out_value and prev_out_grad are non-null does it also
 * read the previous output, its gradient, and reset_output_grad; otherwise
 * those are 0.
 *
 * op_reset_grad fills the slot-0 and slot-1 entries of gate_grad. The
 * previous-output gradient is stored back under the same "both non-null"
 * condition.
 *
 * NOTE(review): reset_output_grad is read only when both prev pointers are
 * present. Presumably the caller fills it only in that case; confirm before
 * relaxing the condition.
 *
 * Launch layout:
 *   threads(frame_per_block, batch_per_block)
 *   grid(frame_blocks, batch_blocks)
 *
 * Row strides:
 *   - gate_value and gate_grad: 3 * frame_size;
 *   - reset_output_grad, prev_out_value and prev_out_grad: frame_size.
 */
template <class OpResetGrad, bool is_batch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size, int batch_size,
                                       ActivationType active_gate) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  int batch_idx = 0;
  if (is_batch) {
    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
    if (batch_idx >= batch_size) return;
    gate_value += batch_idx * 3 * frame_size;
    gate_grad += batch_idx * 3 * frame_size;
    reset_output_grad += batch_idx * frame_size;
  }

  T r_reset_gate_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T r_reset_output_grad = 0;
  T r_update_gate_value = gate_value[frame_idx + frame_size * 0];
  T r_update_gate_grad = gate_grad[frame_idx + frame_size * 0];
  T r_reset_gate_value = gate_value[frame_idx + frame_size * 1];

  // The previous-step terms take part only when both pointers are present.
  // Only then are they offset to this batch row and read.
  const bool has_prev_step = prev_out_value && prev_out_grad;
  if (has_prev_step) {
    if (is_batch) prev_out_value += batch_idx * frame_size;
    if (is_batch) prev_out_grad += batch_idx * frame_size;
    r_prev_out_value = prev_out_value[frame_idx];
    r_prev_out_grad = prev_out_grad[frame_idx];
    r_reset_output_grad = reset_output_grad[frame_idx];
  }

  op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
                &r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
                &r_reset_output_grad, active_gate);

  gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
  gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
  // Write back under the same condition as the read above. Guarding on
  // prev_out_grad alone would, when prev_out_value is null, store through the
  // un-offset pointer. Every batch row would then write row 0 with a value
  // that was never read from it.
  if (has_prev_step) {
    prev_out_grad[frame_idx] = r_prev_out_grad;
  }
}
}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle