/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <type_traits>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"

namespace phi {
namespace funcs {
namespace detail {
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = paddle::framework::EigenVector<T, MajorType, IndexType>;

#if !defined(__NVCC__) && !defined(__HIPCC__)  // @{ Group for GRU CPU
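// The kernels in this group implement one GRU time step on the CPU, one
// batch row at a time. The three gates are packed contiguously along the
// frame dimension: the old layout (old_version == true) stores
// [update | reset | frame_state], the new layout stores
// [reset | update | frame_state]; the candidate (frame) state always sits
// at offset 2 * frame_size.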
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
                                       T *gate_value,
                                       T *reset_output_value,
                                       const T *prev_output_value,
                                       int frame_size,
                                       ActivationType active_gate,
                                       bool old_version = true,
                                       const T *reset_bias = nullptr) {
  T r_value_update_gate;
  T r_value_reset_gate;
  T r_value_reset_output;
  T r_prev_out = 0;
  T r_reset_bias = 0;
  T *update_gate = nullptr;
  T *reset_gate = nullptr;
  if (old_version) {
    update_gate = gate_value;
    reset_gate = gate_value + frame_size;
  } else {
    reset_gate = gate_value;
    update_gate = gate_value + frame_size;
  }
  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_reset_gate = reset_gate[i];
    if (!old_version) {
      r_value_reset_output = reset_output_value[i];
      r_reset_bias = reset_bias[i];
    }
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }

    op_reset_output(&r_value_update_gate,
                    &r_value_reset_gate,
                    &r_prev_out,
                    &r_value_reset_output,
                    active_gate,
                    &r_reset_bias,
                    old_version);

    update_gate[i] = r_value_update_gate;
    reset_gate[i] = r_value_reset_gate;
    reset_output_value[i] = r_value_reset_output;
  }
}

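// Second half of the forward step: the candidate (frame) state is combined
// with the previous hidden state through the update gate to form the new
// hidden state. origin_mode chooses between the two common GRU output
// conventions (which of the previous state and the candidate gets weighted
// by the update gate); the per-element math lives in op_final_output.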
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                       T *gate_value,
                                       const T *prev_output_value,
                                       T *output_value,
                                       int frame_size,
                                       ActivationType active_node,
                                       bool origin_mode,
                                       bool old_version = true) {
  T r_value_update_gate;
  T r_value_frame_state;
  T r_prev_out = 0;
  T r_output;
  T *update_gate;
  if (old_version) {
    update_gate = gate_value;
  } else {
    update_gate = gate_value + frame_size;
  }
  T *frame_state = gate_value + frame_size * 2;

  for (int i = 0; i < frame_size; i++) {
    r_value_update_gate = update_gate[i];
    r_value_frame_state = frame_state[i];
    if (prev_output_value) {
      r_prev_out = prev_output_value[i];
    }

    op_final_output(&r_value_update_gate,
                    &r_value_frame_state,
                    &r_prev_out,
                    &r_output,
                    active_node,
                    origin_mode);

    frame_state[i] = r_value_frame_state;
    output_value[i] = r_output;
  }
}

template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                     T *gate_value,
                                     T *reset_output_value,
                                     const T *prev_output_value,
                                     int frame_size,
                                     ActivationType active_gate,
                                     bool old_version = true,
                                     const T *reset_bias = nullptr) {
#ifdef __AVX__
  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
  __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f);
  __m256 r_value_reset_output;
  __m256 r_prev_out = _mm256_set1_ps(0.0f),
         r_prev_out_last = _mm256_set1_ps(0.0f);
  __m256 r_reset_bias = _mm256_set1_ps(0.0f);
  T *update_gate;
  T *reset_gate;
  if (old_version) {
    update_gate = gate_value;
    reset_gate = gate_value + frame_size;
  } else {
    reset_gate = gate_value;
    update_gate = gate_value + frame_size;
  }
  int block = 8;
  const int n = frame_size;
  const int rest = n % block;
  const int end = n - rest;
  int i = 0;

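  // Tail handling: when frame_size is not a multiple of 8, the last 8 lanes
  // are preloaded here, before the main loop stores into them, and processed
  // once more after it. The overlapping lanes are recomputed from the same
  // original inputs, so the extra pass stays consistent. The AVX final-output
  // kernel below uses the same strategy.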
  if (rest > 0) {
    i = n - block;
    r_value_update_gate_last =
        _mm256_loadu_ps((const float *)(update_gate + i));
    r_value_reset_gate_last = _mm256_loadu_ps((const float *)(reset_gate + i));
    if (prev_output_value) {
      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
    }
  }

  for (i = 0; i < end; i += block) {
    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
    r_value_reset_gate = _mm256_loadu_ps((const float *)(reset_gate + i));
    if (prev_output_value) {
      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
    }
    if (!old_version) {
      r_reset_bias = _mm256_loadu_ps((const float *)(reset_bias + i));
      r_value_reset_output =
          _mm256_loadu_ps((const float *)(reset_output_value + i));
    }

    op_reset_output(&r_value_update_gate,
                    &r_value_reset_gate,
                    &r_prev_out,
                    &r_value_reset_output,
                    active_gate,
                    &r_reset_bias,
                    old_version);

    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
                     r_value_update_gate);
    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
                     r_value_reset_gate);
    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
                     r_value_reset_output);
  }

  if (rest > 0) {
    i = n - block;

    op_reset_output(&r_value_update_gate_last,
                    &r_value_reset_gate_last,
                    &r_prev_out_last,
                    &r_value_reset_output,
                    active_gate,
                    &r_reset_bias,
                    old_version);

    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
                     r_value_update_gate_last);
    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
                     r_value_reset_gate_last);
    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
                     r_value_reset_output);
  }
#endif
}

template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                     T *gate_value,
                                     const T *prev_output_value,
                                     T *output_value,
                                     int frame_size,
                                     ActivationType active_node,
                                     bool origin_mode,
                                     bool old_version = true) {
#ifdef __AVX__
  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
  __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
  __m256 r_prev_out = _mm256_set1_ps(0.0f),
         r_prev_out_last = _mm256_set1_ps(0.0f);
  __m256 r_output;
  T *update_gate;
  if (old_version) {
    update_gate = gate_value;
  } else {
    update_gate = gate_value + frame_size;
  }

  T *frame_state = gate_value + frame_size * 2;
  int block = 8;
  const int n = frame_size;
  const int rest = n % block;
  const int end = n - rest;
  int i = 0;

  if (rest > 0) {
    i = n - block;
    r_value_update_gate_last =
        _mm256_loadu_ps((const float *)(update_gate + i));
    r_value_frame_state_last =
        _mm256_loadu_ps((const float *)(frame_state + i));
    if (prev_output_value) {
      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
    }
  }

  for (i = 0; i < end; i += block) {
    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
    r_value_frame_state = _mm256_loadu_ps((const float *)(frame_state + i));
    if (prev_output_value) {
      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
    }

    op_final_output(&r_value_update_gate,
                    &r_value_frame_state,
                    &r_prev_out,
                    &r_output,
                    active_node,
                    origin_mode);

    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
                     r_value_frame_state);
    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
  }

  if (rest > 0) {
    i = n - block;
    op_final_output(&r_value_update_gate_last,
                    &r_value_frame_state_last,
                    &r_prev_out_last,
                    &r_output,
                    active_node,
                    origin_mode);

    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
                     r_value_frame_state_last);
    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
  }

#endif
}

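// Eigen-based reset-output stage used when old_version == false:
//   r = sigmoid(r), u = sigmoid(u)
//   reset_output = (reset_output + reset_bias) * r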
template <typename T, typename Context>
inline void forward_reset_outputV2(const Context &context,
                                   phi::funcs::GRUMetaValue<T> value,
                                   int frame_size) {
  auto &place = *context.eigen_device();
  auto value_reset_gate =
      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
  auto value_update_gate = typename EigenVector<T>::Type(
      value.gate_value + frame_size, Array1(frame_size));
  auto value_reset_output = typename EigenVector<T>::Type(
      value.reset_output_value, Array1(frame_size));
  auto value_reset_bias =
      typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size));
  SigmoidFunctor<T>()(place, value_reset_gate, value_reset_gate);
  SigmoidFunctor<T>()(place, value_update_gate, value_update_gate);
  value_reset_output.device(place) =
      (value_reset_output + value_reset_bias) * value_reset_gate;
}

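// Per-batch dispatcher for the reset-output stage. The old formulation picks
// the AVX kernel when the functor supports it, frame_size is at least 8 and
// T is float (sizeof(T) == 4), otherwise the scalar kernel; the new
// formulation uses the Eigen path above. The pointers inside `value` are
// advanced by one batch row per iteration.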
template <typename Context, class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
                                 phi::funcs::GRUMetaValue<T> value,
                                 int frame_size,
                                 int batch_size,
                                 ActivationType active_gate,
                                 bool old_version = true,
                                 const Context *context = nullptr) {
  for (int b = 0; b < batch_size; b++) {
    if (!old_version) {
      // use the Eigen-based path (new formulation)
      forward_reset_outputV2(*context, value, frame_size);
    } else {
      if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
          (sizeof(T) == 4)) {
        hl_avx_gru_forward_reset_output(op_reset_output,
                                        value.gate_value,
                                        value.reset_output_value,
                                        value.prev_out_value,
                                        frame_size,
                                        active_gate,
                                        old_version,
                                        value.reset_bias);
      } else {
        hl_naive_gru_forward_reset_output(op_reset_output,
                                          value.gate_value,
                                          value.reset_output_value,
                                          value.prev_out_value,
                                          frame_size,
                                          active_gate,
                                          old_version,
                                          value.reset_bias);
      }
    }
    value.gate_value += frame_size * 3;
    value.reset_output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }
  }
}

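// Eigen-based final-output stage used when old_version == false:
//   candidate = tanh(frame_state)
//   h = (1 - u) * candidate, plus u * h_prev when a previous state exists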
template <typename T, typename Context>
inline void forward_final_outputV2(const Context &context,
                                   phi::funcs::GRUMetaValue<T> value,
                                   int frame_size) {
  auto &place = *context.eigen_device();
  auto value_update_gate = typename EigenVector<T>::Type(
      value.gate_value + frame_size, Array1(frame_size));
  auto value_frame_state = typename EigenVector<T>::Type(
      value.gate_value + 2 * frame_size, Array1(frame_size));
  auto value_output =
      typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
  TanhFunctor<T>()(place, value_frame_state, value_frame_state);
  value_output.device(place) =
      (static_cast<T>(1.0) - value_update_gate) * value_frame_state;
  if (value.prev_out_value) {
    auto value_prev_out = typename EigenVector<T>::ConstType(
        value.prev_out_value, Array1(frame_size));
    value_output.device(place) =
        value_output + value_update_gate * value_prev_out;
  }
}

template <typename Context, class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
                                 phi::funcs::GRUMetaValue<T> value,
                                 int frame_size,
                                 int batch_size,
                                 ActivationType active_node,
                                 bool origin_mode,
                                 bool old_version = true,
                                 const Context *context = nullptr) {
  for (int b = 0; b < batch_size; b++) {
    if (!old_version) {
      // use the Eigen-based path (new formulation)
      forward_final_outputV2(*context, value, frame_size);
    } else {
      if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
          (sizeof(T) == 4)) {
        hl_avx_gru_forward_final_output(op_final_output,
                                        value.gate_value,
                                        value.prev_out_value,
                                        value.output_value,
                                        frame_size,
                                        active_node,
                                        origin_mode,
                                        old_version);
      } else {
        hl_naive_gru_forward_final_output(op_final_output,
                                          value.gate_value,
                                          value.prev_out_value,
                                          value.output_value,
                                          frame_size,
                                          active_node,
                                          origin_mode,
                                          old_version);
      }
    }
    value.gate_value += frame_size * 3;
    value.output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }
  }
}

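// Backward kernels for the old formulation, split into two stages that
// mirror the forward pass: the state stage back-propagates the output
// gradient into the update gate and the candidate state, and the reset
// stage back-propagates the reset-output gradient into the update and reset
// gates. Both stages also update prev_out_grad when it is provided.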
template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad,
                                      T *gate_value,
                                      T *gate_grad,
                                      const T *prev_out_value,
                                      T *prev_out_grad,
                                      T *output_grad,
G
guosheng 已提交
420
                                      int frame_size,
Q
                                      bool origin_mode) {
G
guosheng 已提交
423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_frame_state_value;
  T r_frame_state_grad;
  T r_out_grad;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *frame_state_value = gate_value + frame_size * 2;
  T *frame_state_grad = gate_grad + frame_size * 2;

  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = output_grad[i];
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }

    op_state_grad(&r_update_gate_value,
                  &r_update_gate_grad,
                  &r_frame_state_value,
                  &r_frame_state_grad,
                  &r_prev_out_value,
                  &r_prev_out_grad,
                  &r_out_grad,
                  active_node,
                  origin_mode);

    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}

template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad,
                                      T *gate_value,
                                      T *gate_grad,
                                      const T *prev_out_value,
                                      T *prev_out_grad,
                                      T *reset_output_grad,
                                      int frame_size,
                                      ActivationType active_gate) {
  T r_update_gate_value;
  T r_update_gate_grad;
  T r_reset_gate_value;
  T r_reset_gate_grad;
  T r_reset_output_grad = 0;
  T r_prev_out_value = 0;
  T r_prev_out_grad = 0;
  T *update_gate_value = gate_value;
  T *update_gate_grad = gate_grad;
  T *reset_gate_value = gate_value + frame_size;
  T *reset_gate_grad = gate_grad + frame_size;

  for (int i = 0; i < frame_size; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];

    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = reset_output_grad[i];
    }
    if (prev_out_value) {
      r_prev_out_value = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = prev_out_grad[i];
    }

    op_reset_grad(&r_update_gate_value,
                  &r_update_gate_grad,
                  &r_reset_gate_value,
                  &r_reset_gate_grad,
                  &r_prev_out_value,
                  &r_prev_out_grad,
                  &r_reset_output_grad,
                  active_gate);

    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      prev_out_grad[i] = r_prev_out_grad;
    }
  }
}

template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad,
                                    T *gate_value,
                                    T *gate_grad,
                                    const T *prev_out_value,
                                    T *prev_out_grad,
                                    T *output_grad,
                                    int frame_size,
                                    ActivationType active_node,
                                    bool origin_mode) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_frame_state_value;
  __m256 r_frame_state_grad;
  __m256 r_out_grad;
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = reinterpret_cast<__m256 *>(gate_value);
  __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
  __m256 *frame_state_value =
      reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
  __m256 *frame_state_grad =
      reinterpret_cast<__m256 *>(gate_grad + frame_size * 2);

  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_frame_state_value = frame_state_value[i];
    r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i];
    if (prev_out_value) {
      r_prev_out_value = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
    }

    op_state_grad(&r_update_gate_value,
                  &r_update_gate_grad,
                  &r_frame_state_value,
                  &r_frame_state_grad,
                  &r_prev_out_value,
                  &r_prev_out_grad,
                  &r_out_grad,
                  active_node,
                  origin_mode);

    update_gate_grad[i] = r_update_gate_grad;
    frame_state_grad[i] = r_frame_state_grad;
    if (prev_out_grad) {
      (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_prev_out_grad;
    }
  }
#endif
}

template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad,
                                    T *gate_value,
                                    T *gate_grad,
                                    const T *prev_out_value,
                                    T *prev_out_grad,
                                    T *reset_output_grad,
                                    int frame_size,
                                    ActivationType active_gate) {
#ifdef __AVX__
  __m256 r_update_gate_value;
  __m256 r_update_gate_grad;
  __m256 r_reset_gate_value;
  __m256 r_reset_gate_grad;
  __m256 r_reset_output_grad = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_value = _mm256_set1_ps(0.0f);
  __m256 r_prev_out_grad = _mm256_set1_ps(0.0f);
  __m256 *update_gate_value = reinterpret_cast<__m256 *>(gate_value);
  __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
  __m256 *reset_gate_value =
      reinterpret_cast<__m256 *>(gate_value + frame_size);
  __m256 *reset_gate_grad = reinterpret_cast<__m256 *>(gate_grad + frame_size);

  for (int i = 0; i < frame_size / 8; i++) {
    r_update_gate_value = update_gate_value[i];
    r_update_gate_grad = update_gate_grad[i];
    r_reset_gate_value = reset_gate_value[i];

    if (prev_out_value && prev_out_grad) {
      r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
    }
    if (prev_out_value) {
      r_prev_out_value = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
    }
    if (prev_out_grad) {
      r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
    }

    op_reset_grad(&r_update_gate_value,
                  &r_update_gate_grad,
                  &r_reset_gate_value,
                  &r_reset_gate_grad,
                  &r_prev_out_value,
                  &r_prev_out_grad,
                  &r_reset_output_grad,
                  active_gate);

    update_gate_grad[i] = r_update_gate_grad;
    reset_gate_grad[i] = r_reset_gate_grad;
    if (prev_out_grad) {
      (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_prev_out_grad;
    }
  }
#endif
}

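// Fused backward kernels: hl_naive_gru_backward and hl_avx_gru_backward
// compute the gradients of all three gates, the reset output and the
// previous hidden state in a single pass over the frame, delegating the
// per-element math to the OpGruGrad functor. They assume the new gate layout
// [reset | update | frame_state].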
template <class OpGruGrad, typename T>
inline void hl_naive_gru_backward(OpGruGrad op_gru_grad,
                                  T *gate_value,
                                  T *gate_grad,
                                  const T *prev_out_value,
                                  T *prev_out_grad,
                                  T *reset_output_value,
                                  T *reset_output_grad,
                                  T *output_grad,
                                  int frame_size,
                                  ActivationType active_node,
                                  ActivationType active_gate) {
  T r_value_reset_gate;
  T r_grad_reset_gate;
  T r_value_update_gate;
  T r_grad_update_gate;
  T r_value_frame_state;
  T r_grad_frame_state;
  T r_value_prev_out = 0;
  T r_grad_prev_out = 0;
  T r_grad_output;
  T r_value_reset_output;
  T r_grad_reset_output = 0;
  T *reset_gate_value = gate_value;
  T *reset_gate_grad = gate_grad;
  T *update_gate_value = gate_value + frame_size;
  T *update_gate_grad = gate_grad + frame_size;
  T *frame_state_value = gate_value + 2 * frame_size;
  T *frame_state_grad = gate_grad + 2 * frame_size;

  for (int i = 0; i < frame_size; ++i) {
    r_value_reset_gate = reset_gate_value[i];
    r_grad_reset_gate = reset_gate_grad[i];
    r_value_update_gate = update_gate_value[i];
    r_grad_update_gate = update_gate_grad[i];
    r_value_frame_state = frame_state_value[i];
    r_grad_frame_state = frame_state_grad[i];
    if (prev_out_value) {
      r_value_prev_out = prev_out_value[i];
    }
    if (prev_out_grad) {
      r_grad_prev_out = prev_out_grad[i];
    }
    r_grad_output = output_grad[i];
    r_value_reset_output = reset_output_value[i];
    if (prev_out_value && prev_out_grad) {
      r_grad_reset_output = reset_output_grad[i];
    }

    op_gru_grad(&r_value_reset_gate,
                &r_grad_reset_gate,
                &r_value_update_gate,
                &r_grad_update_gate,
                &r_value_frame_state,
                &r_grad_frame_state,
                &r_value_prev_out,
                &r_grad_prev_out,
                &r_grad_output,
                &r_value_reset_output,
                &r_grad_reset_output,
                active_node,
                active_gate);

    reset_gate_grad[i] = r_grad_reset_gate;
    update_gate_grad[i] = r_grad_update_gate;
    frame_state_grad[i] = r_grad_frame_state;
    if (prev_out_grad) {
      prev_out_grad[i] = r_grad_prev_out;
    }
    if (prev_out_value && prev_out_grad) {
      reset_output_grad[i] = r_grad_reset_output;
    }
  }
}

template <class OpGruGrad, typename T>
inline void hl_avx_gru_backward(OpGruGrad op_gru_grad,
                                T *gate_value,
                                T *gate_grad,
                                const T *prev_out_value,
                                T *prev_out_grad,
                                T *reset_output_value,
                                T *reset_output_grad,
                                T *output_grad,
                                int frame_size,
                                ActivationType active_node,
                                ActivationType active_gate) {
#ifdef __AVX__
  __m256 r_value_reset_gate;
  __m256 r_grad_reset_gate;
  __m256 r_value_update_gate;
  __m256 r_grad_update_gate;
  __m256 r_value_frame_state;
  __m256 r_grad_frame_state;
  __m256 r_value_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_grad_prev_out = _mm256_set1_ps(0.0f);
  __m256 r_grad_output;
  __m256 r_value_reset_output;
  __m256 r_grad_reset_output = _mm256_set1_ps(0.0f);
  __m256 *reset_gate_value = reinterpret_cast<__m256 *>(gate_value);
  __m256 *reset_gate_grad = reinterpret_cast<__m256 *>(gate_grad);
  __m256 *update_gate_value =
      reinterpret_cast<__m256 *>(gate_value + frame_size);
  __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad + frame_size);
  __m256 *frame_state_value =
      reinterpret_cast<__m256 *>(gate_value + 2 * frame_size);
  __m256 *frame_state_grad =
      reinterpret_cast<__m256 *>(gate_grad + 2 * frame_size);

  for (int i = 0; i < frame_size / 8; ++i) {
    r_value_reset_gate = reset_gate_value[i];
    r_grad_reset_gate = reset_gate_grad[i];
    r_value_update_gate = update_gate_value[i];
    r_grad_update_gate = update_gate_grad[i];
    r_value_frame_state = frame_state_value[i];
    r_grad_frame_state = frame_state_grad[i];
    if (prev_out_value) {
      r_value_prev_out = (reinterpret_cast<const __m256 *>(prev_out_value))[i];
    }
    if (prev_out_grad) {
      r_grad_prev_out = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
    }
    r_grad_output = (reinterpret_cast<__m256 *>(output_grad))[i];
    r_value_reset_output = (reinterpret_cast<__m256 *>(reset_output_value))[i];
    if (prev_out_value && prev_out_grad) {
      r_grad_reset_output = (reinterpret_cast<__m256 *>(reset_output_grad))[i];
    }

    op_gru_grad(&r_value_reset_gate,
                &r_grad_reset_gate,
                &r_value_update_gate,
                &r_grad_update_gate,
                &r_value_frame_state,
                &r_grad_frame_state,
                &r_value_prev_out,
                &r_grad_prev_out,
                &r_grad_output,
                &r_value_reset_output,
                &r_grad_reset_output,
                active_node,
                active_gate);

    reset_gate_grad[i] = r_grad_reset_gate;
    update_gate_grad[i] = r_grad_update_gate;
    frame_state_grad[i] = r_grad_frame_state;
    if (prev_out_grad) {
      (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_grad_prev_out;
    }
    if (prev_out_value && prev_out_grad) {
      (reinterpret_cast<__m256 *>(reset_output_grad))[i] = r_grad_reset_output;
    }
  }
#endif
}

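// Per-batch dispatchers for the backward pass of the old formulation. The
// AVX variants index the buffers directly as __m256 arrays with no scalar
// tail, so they are only chosen when frame_size is a multiple of 8
// (!(frame_size & (8 - 1))) and T is float.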
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
                                phi::funcs::GRUMetaValue<T> value,
                                phi::funcs::GRUMetaGrad<T> grad,
                                int frame_size,
                                int batch_size,
                                ActivationType active_node,
                                bool origin_mode) {
  for (int b = 0; b < batch_size; b++) {
    if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_backward_state_grad(op_state_grad,
                                     value.gate_value,
                                     grad.gate_grad,
                                     value.prev_out_value,
                                     grad.prev_out_grad,
                                     grad.output_grad,
                                     frame_size,
                                     active_node,
                                     origin_mode);
    } else {
      hl_naive_gru_backward_state_grad(op_state_grad,
                                       value.gate_value,
                                       grad.gate_grad,
                                       value.prev_out_value,
                                       grad.prev_out_grad,
                                       grad.output_grad,
                                       frame_size,
                                       active_node,
                                       origin_mode);
    }

    value.gate_value += frame_size * 3;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad,
                                phi::funcs::GRUMetaValue<T> value,
                                phi::funcs::GRUMetaGrad<T> grad,
                                int frame_size,
                                int batch_size,
                                ActivationType active_gate) {
  for (int b = 0; b < batch_size; b++) {
    if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
      hl_avx_gru_backward_reset_grad(op_reset_grad,
                                     value.gate_value,
                                     grad.gate_grad,
                                     value.prev_out_value,
                                     grad.prev_out_grad,
                                     grad.reset_output_grad,
                                     frame_size,
                                     active_gate);
    } else {
      hl_naive_gru_backward_reset_grad(op_reset_grad,
                                       value.gate_value,
                                       grad.gate_grad,
                                       value.prev_out_value,
                                       grad.prev_out_grad,
                                       grad.reset_output_grad,
                                       frame_size,
                                       active_gate);
    }

    value.gate_value += frame_size * 3;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.reset_output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

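// Eigen-based backward step for the new formulation (gate layout
// [reset | update | frame_state]). The update-gate gradient comes from the
// sigmoid derivative applied to (h_prev - candidate) * dh, or to
// -candidate * dh when there is no previous state; the candidate gradient
// from the tanh derivative applied to (1 - u) * dh; and the reset-gate
// gradient from the sigmoid derivative applied to
// (reset_output / r) * d_candidate. prev_out_grad and reset_output_grad are
// filled in when the corresponding buffers exist.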
template <typename T, typename Context>
inline void gru_backward(const Context &context,
                         phi::funcs::GRUMetaValue<T> value,
                         phi::funcs::GRUMetaGrad<T> grad,
                         int frame_size) {
  auto &place = *context.eigen_device();

  auto value_reset_gate =
      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
  auto grad_reset_gate =
      typename EigenVector<T>::Type(grad.gate_grad, Array1(frame_size));
  auto value_update_gate = typename EigenVector<T>::Type(
      value.gate_value + frame_size, Array1(frame_size));
  auto grad_update_gate = typename EigenVector<T>::Type(
      grad.gate_grad + frame_size, Array1(frame_size));
  auto value_frame_state = typename EigenVector<T>::Type(
      value.gate_value + frame_size * 2, Array1(frame_size));
  auto grad_frame_state = typename EigenVector<T>::Type(
      grad.gate_grad + frame_size * 2, Array1(frame_size));

  auto grad_output =
      typename EigenVector<T>::Type(grad.output_grad, Array1(frame_size));
  auto value_reset_output = typename EigenVector<T>::Type(
      value.reset_output_value, Array1(frame_size));
  auto grad_reset_output =
      typename EigenVector<T>::Type(grad.reset_output_grad, Array1(frame_size));

  if (value.prev_out_value) {
    auto value_prev_out = typename EigenVector<T>::ConstType(
        value.prev_out_value, Array1(frame_size));
    SigmoidGradFunctor<T>()(place,
                            1 /*useless*/,
                            value_update_gate,
                            (value_prev_out - value_frame_state) * grad_output,
                            grad_update_gate);
  } else {
    SigmoidGradFunctor<T>()(
        place,
        1 /*useless*/,
        value_update_gate,
        static_cast<T>(-1) * value_frame_state * grad_output,
        grad_update_gate);
  }
  if (grad.prev_out_grad) {
    auto grad_prev_out =
        typename EigenVector<T>::Type(grad.prev_out_grad, Array1(frame_size));
    grad_prev_out.device(place) =
        grad_prev_out + grad_output * value_update_gate;
  }
  TanhGradFunctor<T>()(place,
                       1 /*useless*/,
                       value_frame_state,
                       grad_output * (static_cast<T>(1.0) - value_update_gate),
                       grad_frame_state);
  SigmoidGradFunctor<T>()(
      place,
      1 /*useless*/,
      value_reset_gate,
926 927 928 929 930 931 932
      value_reset_output / value_reset_gate * grad_frame_state,
      grad_reset_gate);
  if (value.prev_out_value && grad.prev_out_grad) {
    grad_reset_output.device(place) = value_reset_gate * grad_frame_state;
  }
}

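// Batch loop for the new-formulation backward pass. Each row is handled by
// the Eigen-based gru_backward above; the op_gru_grad functor and the
// activation arguments are not used on this path.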
template <class OpGruGrad, typename T, typename Context>
inline void cpu_gru_backward(const Context &context,
                             OpGruGrad op_gru_grad,
                             phi::funcs::GRUMetaValue<T> value,
                             phi::funcs::GRUMetaGrad<T> grad,
                             int frame_size,
                             int batch_size,
                             ActivationType active_node,
941 942
                             ActivationType active_gate) {
  for (int b = 0; b < batch_size; ++b) {
943 944
    // eigen
    gru_backward(context, value, grad, frame_size);
945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960

    value.gate_value += frame_size * 3;
    value.reset_output_value += frame_size;
    if (value.prev_out_value) {
      value.prev_out_value += frame_size;
    }

    grad.gate_grad += frame_size * 3;
    grad.output_grad += frame_size;
    grad.reset_output_grad += frame_size;
    if (grad.prev_out_grad) {
      grad.prev_out_grad += frame_size;
    }
  }
}

#endif  // @} End Group for GRU CPU

}  // namespace detail
}  // namespace funcs
}  // namespace phi