/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/gru_compute.h"

namespace paddle {
namespace operators {
namespace math {
namespace detail {

#ifndef __NVCC__
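// CPU kernels for the GRU unit, compiled only when not building with NVCC.
// For each sample, gate_value stores three consecutive frame_size-wide
// segments: [update gate | reset gate | frame state (candidate)]. Every
// stage comes in a scalar ("naive") form and an AVX form that processes
// 8 floats per __m256 register; the inline dispatchers choose between them.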

template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
                                       T *gate_value, T *reset_output_value,
                                       T *prev_output_value, int frame_size,
                                       ActivationType active_gate) {
  T r_value_update_gate;
  T r_value_reset_gate;
  T r_value_reset_output;
  T r_prev_out = 0;
  T *update_gate = gate_value;
  T *reset_gate = gate_value + frame_size;

  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }

    op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
                    &r_value_reset_output, active_gate);

    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
    reset_output_value[i] = r_value_reset_output;
  }
}

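// Scalar final-output step: combines the previous hidden state and the
// candidate ("frame state") through the update gate to produce the new
// hidden state for each frame.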
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                       T *gate_value, T *prev_output_value,
                                       T *output_value, int frame_size,
                                       ActivationType active_node) {
  T r_value_update_gate;
  T r_value_frame_state;
  T r_prev_out = 0;
  T r_output;
  T *update_gate = gate_value;
  T *frame_state = gate_value + frame_size * 2;

  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }

    op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
                    &r_output, active_node);

    frame_state[i] = r_value_frame_state;
    output_value[i] = r_output;
  }
}

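// AVX variant of the reset-output step: reinterprets the float arrays as
// arrays of __m256 and handles 8 lanes per iteration, so it is only valid
// when frame_size is a multiple of 8 and T is a 4-byte float (both are
// checked by the dispatchers below).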
template <class OpResetOutput, typename T>
G
guosheng 已提交
83 84 85
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                     T *gate_value, T *reset_output_value,
                                     T *prev_output_value, int frame_size,
86
                                     ActivationType active_gate) {
G
guosheng 已提交
87
#ifdef __AVX__
G
guosheng 已提交
88 89 90 91
  __m256 r_value_update_gate;
  __m256 r_value_reset_gate;
  __m256 r_value_reset_output;
  __m256 r_prev_out = _mm256_set1_ps(0.0f);
92 93
  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
  __m256 *reset_gate = reinterpret_cast<__m256 *>(gate_value + frame_size);
G
guosheng 已提交
94 95 96 97 98

  for (int i = 0; i < frame_size / 8; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (prev_output_value) {
99
      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
G
guosheng 已提交
100 101
    }

102 103
    op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
                    &r_value_reset_output, active_gate);
G
guosheng 已提交
104

G
guosheng 已提交
105 106
    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
107
    (reinterpret_cast<__m256 *>(reset_output_value))[i] = r_value_reset_output;
G
guosheng 已提交
108 109 110 111 112
  }
#endif
}

template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                     T *gate_value, T *prev_output_value,
                                     T *output_value, int frame_size,
                                     ActivationType active_node) {
#ifdef __AVX__
  __m256 r_value_update_gate;
  __m256 r_value_frame_state;
  __m256 r_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_output;
  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
  __m256 *frame_state = reinterpret_cast<__m256 *>(gate_value + frame_size * 2);

  for (int i = 0; i < frame_size / 8; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
    }

    op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
                    &r_output, active_node);

    frame_state[i] = r_value_frame_state;
    (reinterpret_cast<__m256 *>(output_value))[i] = r_output;
  }
#endif
}

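// Batch dispatcher for the reset-output step: uses the AVX kernel when the
// op supports it, frame_size is a multiple of 8 (frame_size & (8 - 1) == 0),
// and T is a 4-byte float; otherwise falls back to the scalar kernel. After
// each sample the pointers advance by 3 * frame_size over the gates and by
// frame_size over the per-frame buffers.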
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
                                 GRUMetaValue<T> value, int frame_size,
                                 int batch_size, ActivationType active_gate) {
  for (int b = 0; b < batch_size; b++) {
    if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_forward_reset_output(
          op_reset_output, value.gate_value, value.reset_output_value,
          value.prev_out_value, frame_size, active_gate);
    } else {
      hl_naive_gru_forward_reset_output(
          op_reset_output, value.gate_value, value.reset_output_value,
          value.prev_out_value, frame_size, active_gate);
    }

    value.gate_value += frame_size * 3;
    value.reset_output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }
  }
}

template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
                                 GRUMetaValue<T> value, int frame_size,
                                 int batch_size, ActivationType active_node) {
  for (int b = 0; b < batch_size; b++) {
    if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                      value.prev_out_value, value.output_value,
                                      frame_size, active_node);
    } else {
      hl_naive_gru_forward_final_output(
          op_final_output, value.gate_value, value.prev_out_value,
          value.output_value, frame_size, active_node);
    }

    value.gate_value += frame_size * 3;
    value.output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }
  }
}
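
// A minimal sketch of how one GRU step might drive these dispatchers. This
// is illustrative only: the functor names are assumed to come from
// detail/gru_kernel.h, and the real call site (the GRUUnitFunctor in
// gru_compute.cc) also runs the GEMMs that fill gate_value before each stage:
//
//   GRUMetaValue<float> value;  // buffers prepared by the caller
//   // ... GEMM: prev_out_value * gate_weight accumulates into gate_value ...
//   forward_reset_output(forward::gru_resetOutput<float>(), value, frame_size,
//                        batch_size, active_gate);
//   // ... GEMM: reset_output_value * state_weight accumulates into the
//   //     frame-state segment of gate_value ...
//   forward_final_output(forward::gru_finalOutput<float>(), value, frame_size,
//                        batch_size, active_node);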

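// Backward pass. Given the gradient flowing into the hidden state
// (output_grad), the state-grad routines produce gradients for the update
// gate and frame state and, when prev_out_grad is non-null, accumulate the
// gradient of the previous step's output; the reset-grad routines then
// complete the update- and reset-gate gradients from reset_output_grad.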
template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                      T *gate_grad, T *prev_out_value,
                                      T *prev_out_grad, T *output_grad,
                                      int frame_size,
                                      ActivationType active_node) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_frame_state_value;
  T r_frame_state_grad;
  T r_out_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *frame_state_value = gate_value + frame_size * 2;
  T *frame_state_grad = gate_grad + frame_size * 2;

  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = output_grad[i];
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }

    op_state_grad(&r_update_gate_value, &r_update_gate_grad,
                  &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
                  &r_prev_out_grad, &r_out_grad, active_node);

    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}

template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                      T *gate_grad, T *prev_out_value,
                                      T *prev_out_grad, T *reset_output_grad,
                                      int frame_size,
                                      ActivationType active_gate) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_reset_gate_value;
  T r_reset_gate_grad;
  T r_reset_output_grad = 0;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *reset_gate_value = gate_value + frame_size;
  T *reset_gate_grad = gate_grad + frame_size;

  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];

    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = reset_output_grad[i];
    }
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }

    op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
                  &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
                  &r_prev_out_grad, &r_reset_output_grad, active_gate);

    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}

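// AVX variants of the backward kernels, with the same memory layout and the
// same eligibility requirements as the AVX forward kernels above.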
template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                    T *gate_grad, T *prev_out_value,
                                    T *prev_out_grad, T *output_grad,
                                    int frame_size,
                                    ActivationType active_node) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_frame_state_value;
  __m256 r_frame_state_grad;
  __m256 r_out_grad;
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = reinterpret_cast<__m256 *>(gate_value);
  __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
  __m256 *frame_state_value =
      reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
  __m256 *frame_state_grad =
      reinterpret_cast<__m256 *>(gate_grad + frame_size * 2);

  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i];
    if (prev_out_value) {
      r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
    }

    op_state_grad(&r_update_gate_value, &r_update_gate_grad,
                  &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
                  &r_prev_out_grad, &r_out_grad, active_node);

    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_prev_out_grad;
    }
  }
#endif
}

template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                    T *gate_grad, T *prev_out_value,
                                    T *prev_out_grad, T *reset_output_grad,
                                    int frame_size,
                                    ActivationType active_gate) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_reset_gate_value;
  __m256 r_reset_gate_grad;
  __m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = reinterpret_cast<__m256 *>(gate_value);
  __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
  __m256 *reset_gate_value =
      reinterpret_cast<__m256 *>(gate_value + frame_size);
  __m256 *reset_gate_grad = reinterpret_cast<__m256 *>(gate_grad + frame_size);

  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];

    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
    }
    if (prev_out_value) {
      r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
    }

    op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
                  &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
                  &r_prev_out_grad, &r_reset_output_grad, active_gate);

    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_prev_out_grad;
    }
  }
#endif
}

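// Batch dispatchers for the backward kernels, applying the same AVX
// eligibility test (op supports AVX, frame_size % 8 == 0, 4-byte T) and the
// same per-sample pointer strides as the forward dispatchers.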
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                int frame_size, int batch_size,
                                ActivationType active_node) {
  for (int b = 0; b < batch_size; b++) {
    if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_backward_state_grad(
          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
    } else {
      hl_naive_gru_backward_state_grad(
          op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.output_grad, frame_size, active_node);
    }

    value.gate_value += frame_size * 3;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad,
                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                int frame_size, int batch_size,
                                ActivationType active_gate) {
  for (int b = 0; b < batch_size; b++) {
    if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_backward_reset_grad(
          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
    } else {
      hl_naive_gru_backward_reset_grad(
          op_reset_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
          grad.prev_out_grad, grad.reset_output_grad, frame_size, active_gate);
    }

    value.gate_value += frame_size * 3;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.reset_output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

#endif

}  // namespace detail
}  // namespace math
}  // namespace operators
}  // namespace paddle