rnn_op.h 85.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <memory>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/operators/utils.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
using TensorList = std::vector<framework::Tensor>;

// Generates `inline bool is_<MODE_NAME>(ctx)`: true when the op's string
// attribute "mode" is exactly the stringified MODE_STR token.
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR)                      \
  inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \
    return ctx.Attr<std::string>("mode") == #MODE_STR;                 \
  }

DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);

// Exchanges the two Tensor pointers referred to by `a` and `b`.
// Must be `inline`: this is a header, and a non-inline, non-template function
// defined here would be emitted by every translation unit that includes it
// (one-definition-rule violation / duplicate-symbol link errors).
// The misspelled name is kept so existing callers keep compiling.
inline void SwapPoniter(Tensor** a, Tensor** b) { std::swap(*a, *b); }

// Fills `mask_matrix` (already resized to {time_step, batch}) with a 0/1 mask
// derived from the per-sample lengths in `sequence_length`.
//
// For sample b with length L: when `is_reverse` is false the first L steps are
// 1 and the rest 0; when true the last L steps are 1 and the leading
// time_step - L steps are 0. `*min_seq_len` receives the smallest length seen
// (starting from time_step).
//
// Lengths are clamped to [0, time_step], and only the first `batch` entries of
// `sequence_length` are read, so malformed input cannot make std::fill write
// outside the scratch buffer.
template <typename T>
void create_mask_matrix(const framework::ExecutionContext& context,
                        const Tensor* sequence_length, Tensor* mask_matrix,
                        const bool& is_reverse, int* min_seq_len) {
  const auto& seq_len_vec = GetDataFromTensor<int>(sequence_length);
  const int table_width = static_cast<int>(mask_matrix->dims()[0]);
  const size_t batch = static_cast<size_t>(mask_matrix->dims()[1]);

  // Build the mask transposed ({batch, time_step}) so each sample's steps are
  // contiguous, then transpose into `mask_matrix` at the end.
  Tensor temp;
  temp.Resize(
      framework::make_ddim({mask_matrix->dims()[1], mask_matrix->dims()[0]}));
  T* data_temp = temp.mutable_data<T>(context.GetPlace());
  std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast<T>(1.0));

  *min_seq_len = table_width;
  const size_t num_rows = std::min(seq_len_vec.size(), batch);
  for (size_t i = 0; i < num_rows; ++i) {
    const int seq_len = std::min(std::max(seq_len_vec[i], 0), table_width);
    *min_seq_len = std::min(seq_len, *min_seq_len);
    if (seq_len == table_width) {
      continue;  // no padding for this sample
    }
    T* row = data_temp + i * static_cast<size_t>(table_width);
    if (is_reverse) {
      // padding goes first
      std::fill(row, row + (table_width - seq_len), static_cast<T>(0));
    } else {
      // padding goes last
      std::fill(row + seq_len, row + table_width, static_cast<T>(0));
    }
  }

  mask_matrix->mutable_data<T>(context.GetPlace());
  std::vector<int> trans_vec = {1, 0};
  auto& dev_ctx = context.template device_context<platform::CPUDeviceContext>();
  TransCompute<platform::CPUDeviceContext, T>(2, dev_ctx, temp, mask_matrix,
                                              trans_vec);
}

// Polymorphic interface for a single recurrent time step. Concrete cells
// (SimpleRNNCell, GRUCell, LSTMCell) override operator(); not every cell uses
// every argument. The base implementation does nothing.
template <typename T>
struct Cell {
  virtual ~Cell() = default;

  virtual void operator()(const platform::CPUDeviceContext* device_ctx,
                          Tensor* input, const Tensor* weight_hh,
                          const Tensor* init_h, const Tensor* init_c,
                          Tensor* last_h, Tensor* last_c, Tensor* last_c_act,
                          Tensor* output, const Tensor* bias_hh,
                          Tensor* weight_hh_gru) const {
    // Intentionally empty; subclasses provide the computation.
  }
};

// Plain RNN step:
//   input  += init_h * weight_hh^T        (accumulated in place)
//   output  = EigenActivationFunctor(input)
// Only input, weight_hh, init_h and output are used.
template <typename T, template <typename> class EigenActivationFunctor,
          math::detail::ActivationType act_type>
struct SimpleRNNCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    // Fold init_h's leading (batch) dimension into the row count so the
    // product is issued as one plain GEMM rather than a batched matmul.
    const auto rhs_desc =
        math::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    auto lhs_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    lhs_desc.height_ *= lhs_desc.batch_size_;
    lhs_desc.batch_size_ = 0;

    const T one = static_cast<T>(1.0);
    auto blas_handle = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    blas_handle.MatMul(*init_h, lhs_desc, *weight_hh, rhs_desc, one, input,
                       one);

    auto pre_act = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(input, "Input", "z", "Activation"));
    auto activated = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(output, "Output", "hidden", "Activation"));

    EigenActivationFunctor<T> activation;
    activation(*device_ctx->eigen_device(), pre_act, activated);
  }
};

// GRU step. First accumulates init_h * weight_hh_gru^T into `input` (callers in
// this file pass a copy of weight_hh whose last third has been zeroed), then
// hands the gate pre-activations to math::GRUUnitFunctorV2 together with
// weight_hh (gate/state weights), the last third of bias_hh (reset_bias) and
// the buffers for the reset output (`last_c`) and the new hidden state
// (`output`).
template <typename T>
struct GRUCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    // Single GEMM over all rows: collapse init_h's batch dim into its height.
    const auto rhs_desc =
        math::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true);
    auto lhs_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    lhs_desc.height_ *= lhs_desc.batch_size_;
    lhs_desc.batch_size_ = 0;

    const T one = static_cast<T>(1.0);
    auto blas_handle = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    blas_handle.MatMul(*init_h, lhs_desc, *weight_hh_gru, rhs_desc, one, input,
                       one);

    const size_t batch_size = init_h->dims()[1];
    const size_t frame_size = init_h->dims()[2];

    const T* hh_weights = weight_hh->data<T>();
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = hh_weights;
    gru_value.state_weight = hh_weights + 2 * frame_size * frame_size;
    gru_value.reset_bias = bias_hh->data<T>() + 2 * frame_size;
    gru_value.prev_out_value = init_h->data<T>();
    gru_value.gate_value = input->data<T>();
    gru_value.reset_output_value = last_c->data<T>();
    gru_value.output_value = output->data<T>();

    const auto gate_act = math::detail::GetActivationType("sigmoid_v2");
    const auto cand_act = math::detail::GetActivationType("tanh_v2");

    math::GRUUnitFunctorV2<platform::CPUDeviceContext, T>::compute(
        *device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act);
  }
};

// LSTM step. Accumulates init_h * weight_hh^T into `input` (the gate
// pre-activations), then runs math::LstmUnitFunctor with sigmoid_v2 gates and
// tanh_v2 cell/candidate activations and cell_clip = 0. Reads the previous
// cell state from `init_c`; writes the new cell state to `last_c`, its
// activation to `last_c_act` and the new hidden state to `output`.
// When `last_c_act` is null (inference) a scratch buffer is used instead.
template <typename T>
struct LSTMCell : Cell<T> {
  void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input,
                  const Tensor* weight_hh, const Tensor* init_h,
                  const Tensor* init_c, Tensor* last_h, Tensor* last_c,
                  Tensor* last_c_act, Tensor* output, const Tensor* bias_hh,
                  Tensor* weight_hh_gru) const override {
    // Single GEMM over all rows: collapse init_h's batch dim into its height.
    const auto rhs_desc =
        math::CreateMatrixDescriptor(weight_hh->dims(), 0, true);
    auto lhs_desc = math::CreateMatrixDescriptor(init_h->dims(), 0, false);
    lhs_desc.height_ *= lhs_desc.batch_size_;
    lhs_desc.batch_size_ = 0;

    const T one = static_cast<T>(1.0);
    auto blas_handle = math::GetBlas<platform::CPUDeviceContext, T>(*device_ctx);
    blas_handle.MatMul(*init_h, lhs_desc, *weight_hh, rhs_desc, one, input,
                       one);

    const size_t batch_size = init_h->dims()[1];
    const size_t frame_size = init_h->dims()[2];

    // Scratch storage for the activated cell state when the caller did not
    // supply one; it must stay alive until compute() below returns.
    Tensor scratch_act;
    Tensor* cell_act_out = last_c_act;
    if (cell_act_out == nullptr) { /* is test */
      scratch_act.mutable_data<T>(init_h->dims(), device_ctx->GetPlace());
      cell_act_out = &scratch_act;
    }

    math::LstmMetaValue<T> lstm_value;
    // check_* pointers are left null (presumably peephole weights, unused here).
    lstm_value.check_ig = nullptr;
    lstm_value.check_fg = nullptr;
    lstm_value.check_og = nullptr;
    lstm_value.prev_state_value = init_c->data<T>();
    lstm_value.gate_value = input->data<T>();
    lstm_value.output_value = output->data<T>();
    lstm_value.state_value = last_c->data<T>();
    lstm_value.state_active_value = cell_act_out->data<T>();

    const auto gate_act = math::detail::GetActivationType("sigmoid_v2");
    const auto cell_act = math::detail::GetActivationType("tanh_v2");
    const auto cand_act = math::detail::GetActivationType("tanh_v2");
    const T cell_clip = static_cast<T>(0.0);

    math::LstmUnitFunctor<platform::CPUDeviceContext, T>::compute(
        *device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act,
        cell_act, cand_act, false);
  }
};

213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
// Applies an existing 0/1 dropout mask:
//   y = x * mask / (1 - dropout_prob)   if dropout_prob != 1
//   y = 0 * x                           if dropout_prob == 1
// `x` and `y` may be the same tensor (the grad path passes one tensor twice).
template <typename T>
void dropout_helper(const framework::ExecutionContext& context, Tensor* x,
                    Tensor* y, const Tensor* mask, const float& dropout_prob) {
  auto& eigen_place =
      *context.template device_context<platform::CPUDeviceContext>()
           .eigen_device();
  auto keep_mask = EigenVector<uint8_t>::Flatten(*mask);
  auto src = EigenVector<T>::Flatten(*x);
  auto dst = EigenVector<T>::Flatten(*y);

  if (dropout_prob != 1.0f) {
    dst.device(eigen_place) =
        src * keep_mask.cast<T>() / static_cast<T>(1.0f - dropout_prob);
    return;
  }
  // Every element is dropped.
  dst.device(eigen_place) = static_cast<T>(0) * src;
}

229 230
// Training-time dropout of `x` into `y` using (and, on first use, generating)
// the uint8 keep/drop mask in `mask`.
//
// The mask is drawn only while `*is_has_reset` is false; the flag is then set
// so later calls with the same flag reuse the identical mask. Each element is
// dropped (mask 0) with probability `dropout_prob`, using the CPU random engine
// for `seed_number`; the scaling is done by dropout_helper. When `is_test` is
// true the function returns immediately without writing `y`.
template <typename T>
void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
                                  Tensor* x, Tensor* y, Tensor* mask,
                                  const float& dropout_prob,
                                  const int& seed_number, bool is_test,
                                  bool* is_has_reset) {
  if (is_test) {
    return;
  }
  size_t size = framework::product(x->dims());
  auto* mask_data = mask->data<uint8_t>();
  if (!(*is_has_reset)) {
    if (dropout_prob == 1.0f) {
      // Special case: every element is dropped.
      std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
    } else {
      auto engine = framework::GetCPURandomEngine(seed_number);
      std::uniform_real_distribution<float> dist(0, 1);
      for (size_t i = 0; i < size; ++i) {
        mask_data[i] =
            dist(*engine) < dropout_prob ? static_cast<uint8_t>(0)
                                         : static_cast<uint8_t>(1);
      }
    }
    *is_has_reset = true;
  }
  dropout_helper<T>(context, x, y, mask, dropout_prob);
}

// Backward of dropout_cpu_function_inplace: rescales `grad_x` in place with the
// same mask, grad_x = grad_x * mask / (1 - dropout_prob), or multiplies it by
// 0 when dropout_prob is 1.
template <typename T>
void dropout_cpu_grad_function_inplace(
    const framework::ExecutionContext& context, Tensor* grad_x,
    const Tensor* mask, const float& dropout_prob) {
  dropout_helper<T>(context, grad_x, grad_x, mask, dropout_prob);
}

// Base for one RNN layer built on a CellType (SimpleRNNCell / GRUCell /
// LSTMCell).
//   * preprocess  - input projection input * W_ih^T + biases for all steps;
//   * postprocess - sequence-length masking of one step's output/state;
//   * RunTestIter - inference step loop (ping-pongs a few state buffers);
//   * RunIter     - training step loop (keeps every step's cell state).
// Parameter list `vec` holds, per direction `offset`, four tensors:
// vec[0 + 4*offset] = W_ih, vec[1 + 4*offset] = W_hh,
// vec[2 + 4*offset] = bias_ih, vec[3 + 4*offset] = bias_hh.
template <typename T, typename CellType>
struct Layer {
  explicit Layer(const CellType& cell) : cell_(cell) {}
  virtual ~Layer() {}

  // cache_input = input * weight^T + bias_ih + bias_hh', computed as one GEMM
  // over every time step. For GRU the last third of bias_hh is zeroed first
  // (GRUCell passes that part to the GRU functor as reset_bias instead).
  // cache_input is resized to {input.dims[0], input.dims[1], weight.dims[0]};
  // it is only allocated here when is_test is true (presumably the caller
  // allocates it for training — TODO confirm).
  void preprocess(const framework::ExecutionContext& context,
                  const Tensor* input, const Tensor& weight,
                  const Tensor& bias_ih, const Tensor& bias_hh,
                  Tensor* cache_input, bool is_test) {
    // create the temp input for the X * W_ih^T + Bias_ih
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const int& hidden_size = weight.dims()[0];
    cache_input->Resize(framework::make_ddim(
        {input->dims()[0], input->dims()[1], hidden_size}));
    if (is_test) {
      cache_input->mutable_data<T>(context.GetPlace());
    }
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
    auto mat_dim_a = math::CreateMatrixDescriptor(input->dims(), 0, false);
    auto mat_dim_b = math::CreateMatrixDescriptor(weight.dims(), 0, true);
    // convert the batch matmul to a plain matmul, which runs faster
    mat_dim_a.height_ *= mat_dim_a.batch_size_;
    mat_dim_a.batch_size_ = 0;
    blas.MatMul(*input, mat_dim_a, weight, mat_dim_b, static_cast<T>(1.0),
                cache_input, static_cast<T>(0));

    // View cache_input as a [dims[0] * dims[1], dims[2]] matrix and add the
    // biases to every row.
    auto in = framework::EigenMatrix<T>::Reshape(
        *cache_input, cache_input->dims().size() - 1);
    auto bias_ih_tmp = framework::EigenMatrix<T>::From(
        bias_ih, framework::make_ddim({1, bias_ih.dims()[0]}));
    const int& row_num =
        framework::product(cache_input->dims()) / cache_input->dims()[2];
    in = in + bias_ih_tmp.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
    if (is_gru(context)) {
      // reset_gate update_gate cell_gate = [1, 1, 0]
      Tensor bias_hh_tmp;
      bias_hh_tmp.Resize({bias_hh.numel()});
      bias_hh_tmp.mutable_data<T>(context.GetPlace());
      framework::TensorCopy(bias_hh, context.GetPlace(), dev_ctx, &bias_hh_tmp);
      bias_hh_tmp.Resize({3, bias_hh_tmp.numel() / 3});
      auto bias_hh_tmp_unbind = Unbind(bias_hh_tmp);
      math::SetConstant<platform::CPUDeviceContext, T> zero;
      zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast<T>(0.0));

      auto bias_hh_after_mask = framework::EigenMatrix<T>::From(
          bias_hh_tmp, framework::make_ddim({1, bias_hh.dims()[0]}));
      in = in + bias_hh_after_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
    } else {
      auto bias_hh_no_mask = framework::EigenMatrix<T>::From(
          bias_hh, framework::make_ddim({1, bias_hh.dims()[0]}));
      in = in + bias_hh_no_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
    }
  }

  // Applies one step's 0/1 mask (one entry per batch row):
  //   last_h = output * mask + init_h * (1 - mask)
  //   output = output * mask
  // and for LSTM also last_c = last_c * mask + init_c * (1 - mask), so padded
  // rows produce zero output and carry their previous state forward.
  void postprocess(const framework::ExecutionContext& context, Tensor* output,
                   const Tensor* init_h, const Tensor* init_c, Tensor* last_h,
                   Tensor* last_c, const Tensor& mask_tensor) {
    // where the mask flag is 0 the output is zeroed
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto out =
        framework::EigenMatrix<T>::Reshape(*output, output->dims().size() - 1);
    auto mask = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
    auto pre_h =
        framework::EigenMatrix<T>::Reshape(*init_h, init_h->dims().size() - 1);
    auto curr_h =
        framework::EigenMatrix<T>::Reshape(*last_h, last_h->dims().size() - 1);
    auto mask_broadcast =
        mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));
    curr_h.device(place) = out * mask_broadcast + pre_h * (1 - mask_broadcast);
    out.device(place) = out * mask_broadcast;

    if (is_lstm(context)) {
      auto pre_c = framework::EigenMatrix<T>::Reshape(
          *init_c, init_c->dims().size() - 1);
      auto curr_c = framework::EigenMatrix<T>::Reshape(
          *last_c, last_c->dims().size() - 1);
      curr_c.device(place) =
          curr_c * mask_broadcast + pre_c * (1 - mask_broadcast);
    }
  }

  // Entry point implemented by SingleLayer / BidirLayer; no-op here.
  virtual void operator()(const framework::ExecutionContext& context,
                          const Tensor* input, const TensorList& vec,
                          const TensorList& init_h, const TensorList& init_c,
                          const Tensor* sequence_length, TensorList last_h,
                          TensorList last_c, Tensor* output,
                          const int& layer_idx, const int& gate_num,
                          Tensor* gate_value, Tensor* cell_value,
                          Tensor* cell_act_value, bool is_test) {}

  // Inference step loop. Runs cell_ over every step of `input` (in reverse
  // order for the backward direction of a bidirectional layer), writing
  // per-step outputs into `output` and final states into
  // (*last_h_ptr)[layer_idx] / (*last_c_ptr)[layer_idx]. Cell states are kept
  // in two ping-pong buffers instead of per step.
  void RunTestIter(const framework::ExecutionContext& context,
                   const Tensor* input, const TensorList& vec,
                   const TensorList& init_h, const TensorList& init_c,
                   const Tensor* sequence_length, TensorList* last_h_ptr,
                   TensorList* last_c_ptr, Tensor* output, int layer_idx,
                   Tensor* gate_value, Tensor* cell_value,
                   Tensor* cell_act_value, bool is_bidirect, int offset) {
    bool is_reverse = false;
    if (is_bidirect) {
      layer_idx = 2 * layer_idx + offset;
      if (offset > 0) {
        is_reverse = true;
      }
    }
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const int& time_step = input->dims()[0];
    this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
                     vec[3 + offset * 4], gate_value, true);
    auto input_tensors = Unbind(*gate_value);
    auto output_tensors = Unbind(*output);
    if (is_reverse) {
      std::reverse(input_tensors.begin(), input_tensors.end());
      std::reverse(output_tensors.begin(), output_tensors.end());
    }
    TensorList mask_tensor_list;
    // construct the mask matrix for the mask
    bool has_sequence_length = false;
    if (sequence_length != nullptr) {
      has_sequence_length = true;
    }
    Tensor mask_matrix;
    int mask_min_length = time_step;
    if (has_sequence_length) {
      mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));

      create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
                            &mask_min_length);
      mask_tensor_list = Unbind(mask_matrix);
    }
    // Step i needs masking when reverse_flag * i >= mask_min_length: forward,
    // i >= shortest length; reverse, i <= time_step - 1 - shortest length.
    if (is_reverse) {
      mask_min_length = mask_min_length - time_step + 1;
    }
    bool has_allocate_mem_c = false;
    bool has_use_last_h_holder = false;
    const int& reverse_flag = is_reverse ? -1 : 1;

    // define the init_h holder for the swap
    Tensor init_h_temp;
    framework::TensorCopy(init_h[layer_idx], context.GetPlace(), dev_ctx,
                          &init_h_temp);
    Tensor* init_h_holder = &init_h_temp;
    Tensor* last_h_holder = nullptr;
    if (0 < mask_min_length) {
      last_h_holder = &(output_tensors[0]);
    } else {
      last_h_holder = &(*last_h_ptr)[layer_idx];
      has_use_last_h_holder = true;
    }

    Tensor* init_c_holder = nullptr;
    const Tensor* init_c_temp_holder = nullptr;
    Tensor init_c_temp;
    Tensor* last_c_holder = nullptr;
    Tensor last_c_temp;

    if (is_lstm(context)) {
      last_c_holder = &(*last_c_ptr)[layer_idx];
      init_c_temp_holder = &init_c[layer_idx];
    } else if (is_gru(context)) {
      // for reset output value
      last_c_temp.Resize(init_h[layer_idx].dims());
      last_c_temp.mutable_data<T>(context.GetPlace());
      last_c_holder = &last_c_temp;
    }
    // For GRU: a copy of W_hh whose last third (the [1, 1, 0] gate mask) is
    // zeroed, used by GRUCell for its fused recurrent matmul.
    Tensor weight_hh_tmp;  // for gru
    if (is_gru(context)) {
      weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
      weight_hh_tmp.mutable_data<T>(context.GetPlace());
      framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
                            &weight_hh_tmp);
      weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
      auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
      math::SetConstant<platform::CPUDeviceContext, T> zero;
      zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
      weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
    }
    for (int i = 0; i < time_step; i++) {
      bool in_mask = (reverse_flag * i) >= mask_min_length;
      if (i > 0) {
        if (!has_allocate_mem_c) {
          if (is_lstm(context) || is_gru(context)) {
            init_c_temp.Resize(init_h[layer_idx].dims());
            init_c_temp.mutable_data<T>(context.GetPlace());
            init_c_holder = &init_c_temp;
          }
          has_allocate_mem_c = true;
        }
        // previous step's output cell state becomes this step's input
        SwapPoniter(&init_c_holder, &last_c_holder);
        init_c_temp_holder = init_c_holder;
      }
      cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
            init_c_temp_holder, last_h_holder, last_c_holder, nullptr,
            &output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
            &weight_hh_tmp);
      if (in_mask) {
        this->postprocess(context, &output_tensors[i], init_h_holder,
                          init_c_temp_holder, last_h_holder, last_c_holder,
                          mask_tensor_list[i]);
      }
      // prepare next step
      if (i + 1 < time_step) {
        bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
        if (next_step_mask) {
          if (!has_use_last_h_holder) {
            init_h_holder = &(*last_h_ptr)[layer_idx];
          }
        } else {
          init_h_holder = &(output_tensors[i + 1]);
        }
        SwapPoniter(&init_h_holder, &last_h_holder);
      }
    }
    if (has_sequence_length) {
      if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
        framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
                              &(*last_h_ptr)[layer_idx]);
      }
    } else {
      framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
                            dev_ctx, &(*last_h_ptr)[layer_idx]);
    }

    // The cell buffers were swapped time_step - 1 times; for an odd step count
    // last_c_holder already points at (*last_c_ptr)[layer_idx], so only an
    // even count needs the final copy.
    if (time_step % 2 == 0) {
      if (is_lstm(context)) {
        framework::TensorCopy(*last_c_holder, context.GetPlace(), dev_ctx,
                              &(*last_c_ptr)[layer_idx]);
      }
    }
  }

  // Training step loop (delegates to RunTestIter when is_test). Same stepping
  // as RunTestIter, but every step's cell state is written into its own slice
  // of `cell_value` (and, for LSTM, the activated state into `cell_act_value`),
  // presumably for the backward pass — TODO confirm.
  void RunIter(const framework::ExecutionContext& context, const Tensor* input,
               const TensorList& vec, const TensorList& init_h,
               const TensorList& init_c, const Tensor* sequence_length,
               TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output,
               int layer_idx, Tensor* gate_value, Tensor* cell_value,
               Tensor* cell_act_value, bool is_bidirect, int offset,
               bool is_test) {
    if (is_test) {
      RunTestIter(context, input, vec, init_h, init_c, sequence_length,
                  last_h_ptr, last_c_ptr, output, layer_idx, gate_value,
                  cell_value, cell_act_value, is_bidirect, offset);
      return;
    }
    bool is_reverse = false;
    if (is_bidirect) {
      layer_idx = 2 * layer_idx + offset;
      if (offset > 0) {
        is_reverse = true;
      }
    }
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const int& time_step = input->dims()[0];
    this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4],
                     vec[3 + offset * 4], gate_value, is_test);
    auto input_tensors = Unbind(*gate_value);
    auto output_tensors = Unbind(*output);
    if (is_reverse) {
      std::reverse(input_tensors.begin(), input_tensors.end());
      std::reverse(output_tensors.begin(), output_tensors.end());
    }
    TensorList mask_tensor_list;
    // construct the mask matrix for the mask
    bool has_sequence_length = false;
    if (sequence_length != nullptr) {
      has_sequence_length = true;
    }
    Tensor mask_matrix;
    int mask_min_length = time_step;
    if (has_sequence_length) {
      mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
      create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
                            &mask_min_length);
      mask_tensor_list = Unbind(mask_matrix);
    }
    // See RunTestIter: step i is masked when reverse_flag * i >= this value.
    if (is_reverse) {
      mask_min_length = mask_min_length - time_step + 1;
    }

    // define the init_h holder for the swap
    bool has_use_last_h_holder = false;
    const int& reverse_flag = is_reverse ? -1 : 1;

    TensorList cell_value_tensors;
    TensorList cell_act_value_tensors;

    Tensor init_h_temp;
    framework::TensorCopy(init_h[layer_idx], context.GetPlace(), dev_ctx,
                          &init_h_temp);
    Tensor* init_h_holder = &init_h_temp;
    Tensor* last_h_holder = nullptr;
    if (0 < mask_min_length) {
      last_h_holder = &(output_tensors[0]);
    } else {
      last_h_holder = &(*last_h_ptr)[layer_idx];
      has_use_last_h_holder = true;
    }

    const Tensor* init_c_holder = nullptr;
    Tensor* last_c_holder = nullptr;
    Tensor* last_c_act_holder = nullptr;
    if (is_lstm(context) || is_gru(context)) {
      cell_value->Resize({time_step, cell_value->numel() / time_step});
      cell_value_tensors = Unbind(*cell_value);
      if (is_lstm(context)) {
        cell_act_value->Resize(
            {time_step, cell_act_value->numel() / time_step});
        cell_act_value_tensors = Unbind(*cell_act_value);
      }
    }
    // For GRU: a copy of W_hh whose last third (the [1, 1, 0] gate mask) is
    // zeroed, used by GRUCell for its fused recurrent matmul.
    Tensor weight_hh_tmp;  // for gru
    if (is_gru(context)) {
      weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
      weight_hh_tmp.mutable_data<T>(context.GetPlace());
      framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx,
                            &weight_hh_tmp);
      weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3});
      auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp);
      math::SetConstant<platform::CPUDeviceContext, T> zero;
      zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast<T>(0.0));
      weight_hh_tmp.Resize(vec[1 + offset * 4].dims());
    }
    for (int i = 0; i < time_step; i++) {
      bool in_mask = (reverse_flag * i) >= mask_min_length;
      if (is_lstm(context)) {
        if (i == 0) {
          init_c_holder = &init_c[layer_idx];
        } else {
          init_c_holder = &cell_value_tensors[i - 1];
        }
        cell_value_tensors[i].Resize(init_c[layer_idx].dims());
        cell_act_value_tensors[i].Resize(init_c[layer_idx].dims());
        last_c_holder = &cell_value_tensors[i];
        last_c_act_holder = &cell_act_value_tensors[i];
      } else if (is_gru(context)) {
        cell_value_tensors[i].Resize(init_h[layer_idx].dims());
        last_c_holder = &cell_value_tensors[i];
      }

      cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder,
            init_c_holder, last_h_holder, last_c_holder, last_c_act_holder,
            &output_tensors[i], &vec[3 + offset * 4] /* bias_hh */,
            &weight_hh_tmp);
      if (in_mask) {
        this->postprocess(context, &output_tensors[i], init_h_holder,
                          init_c_holder, last_h_holder, last_c_holder,
                          mask_tensor_list[i]);
      }
      // prepare next step
      if (i + 1 < time_step) {
        bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length;
        if (next_step_mask) {
          if (!has_use_last_h_holder) {
            init_h_holder = &(*last_h_ptr)[layer_idx];
          }
        } else {
          init_h_holder = &(output_tensors[i + 1]);
        }
        SwapPoniter(&init_h_holder, &last_h_holder);
      }
    }
    if (has_sequence_length) {
      if (last_h_holder != &(*last_h_ptr)[layer_idx]) {
        framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx,
                              &(*last_h_ptr)[layer_idx]);
      }
    } else {
      framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(),
                            dev_ctx, &(*last_h_ptr)[layer_idx]);
    }
    if (is_lstm(context)) {
      framework::TensorCopy(cell_value_tensors[time_step - 1],
                            context.GetPlace(), dev_ctx,
                            &(*last_c_ptr)[layer_idx]);
    }
  }
  // Cell for the rnn module
  CellType cell_;
};

// Single-direction layer: runs Layer::RunIter forward over the whole sequence
// for `layer_idx` (is_bidirect = false, offset = 0).
template <typename T, typename CellType>
struct SingleLayer : public Layer<T, CellType> {
  explicit SingleLayer(const CellType& cell) : Layer<T, CellType>(cell) {}
  // `override` lets the compiler verify this matches Layer::operator().
  void operator()(const framework::ExecutionContext& context,
                  const Tensor* input, const TensorList& vec,
                  const TensorList& init_h, const TensorList& init_c,
                  const Tensor* sequence_length, TensorList last_h,
                  TensorList last_c, Tensor* output, const int& layer_idx,
                  const int& gate_num, Tensor* gate_value, Tensor* cell_value,
                  Tensor* cell_act_value, bool is_test) override {
    this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h,
                  &last_c, output, layer_idx, gate_value, cell_value,
                  cell_act_value, /*is_bidirect=*/false, /*offset=*/0,
                  is_test);
  }
};

template <typename T, typename CellType>
struct BidirLayer : public Layer<T, CellType> {
  explicit BidirLayer(const CellType& cell) : Layer<T, CellType>(cell) {}

  // Runs both directions of one bidirectional layer.
  // Direction 0 is the forward pass and direction 1 the backward pass; each
  // writes into its own half-width buffer. The two buffers are then
  // concatenated along axis 2 (the last axis of the 3-D output) into
  // `output`.
  // When training (`!is_test`), the gate/cell/cell-activation buffers are
  // reshaped to [2, numel / 2] so that each direction gets its own slice.
  // `gate_num` is accepted for interface parity and unused.
  void operator()(const framework::ExecutionContext& context,
                  const Tensor* input, const TensorList& vec,
                  const TensorList& init_h, const TensorList& init_c,
                  const Tensor* sequence_length, TensorList last_h,
                  TensorList last_c, Tensor* output, const int& layer_idx,
                  const int& gate_num, Tensor* gate_value, Tensor* cell_value,
                  Tensor* cell_act_value, bool is_test) {
    const int time_step = input->dims()[0];
    const int batch_size = input->dims()[1];
    const int hidden_size = output->dims()[2];

    TensorList direction_outputs(2);
    for (auto& half : direction_outputs) {
      half.Resize({time_step, batch_size, hidden_size / 2});
      half.mutable_data<T>(context.GetPlace());
    }

    // Per-direction views into the training buffers; they stay empty in
    // inference mode.
    Tensor gate_views[2];
    Tensor cell_views[2];
    Tensor cell_act_views[2];
    if (!is_test) {
      gate_value->Resize({2, gate_value->numel() / 2});
      gate_views[0] = gate_value->Slice(0, 1);
      gate_views[1] = gate_value->Slice(1, 2);

      const bool lstm = is_lstm(context);
      if (lstm || is_gru(context)) {
        cell_value->Resize({2, cell_value->numel() / 2});
        cell_act_value->Resize({2, cell_act_value->numel() / 2});
        cell_views[0] = cell_value->Slice(0, 1);
        cell_views[1] = cell_value->Slice(1, 2);
        if (lstm) {
          cell_act_views[0] = cell_act_value->Slice(0, 1);
          cell_act_views[1] = cell_act_value->Slice(1, 2);
        }
      }
    }

    for (int direction = 0; direction < 2; ++direction) {
      this->RunIter(context, input, vec, init_h, init_c, sequence_length,
                    &last_h, &last_c, &direction_outputs[direction], layer_idx,
                    &gate_views[direction], &cell_views[direction],
                    &cell_act_views[direction], true, direction, is_test);
    }

    // Join the two half-width results along the hidden axis.
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    paddle::operators::math::ConcatFunctor<platform::CPUDeviceContext, T>
        concat_functor;
    concat_functor(dev_ctx, direction_outputs, static_cast<int>(2), output);
  }
};

// Carves `reserve_data` into views along its first dimension, with
// L = num_layers:
//   rows [0, gate_num * L)          -> gate_data
//   next L rows (lstm and gru only) -> cell_data
//   next L rows (lstm only)         -> cell_act_data
//   next L - 1 rows (only if L > 1) -> hidden_data
// direction_num, time_step, batch_size and hidden_size are not used here.
template <typename TensorType>
void SplitReserveData(const framework::ExecutionContext& ctx,
                      TensorType* reserve_data, Tensor* gate_data,
                      Tensor* cell_data, Tensor* cell_act_data,
                      Tensor* hidden_data, int direction_num,
                      const int& time_step, const int& batch_size,
                      const int& hidden_size, const int& gate_num,
                      const int& num_layers) {
  const int gate_end = gate_num * num_layers;
  const int cell_end = gate_end + num_layers;
  const int cell_act_end = cell_end + num_layers;

  *gate_data = reserve_data->Slice(0, gate_end);

  // For a simple RNN the hidden rows follow the gate rows directly.
  int hidden_begin = gate_end;
  const bool lstm = is_lstm(ctx);
  if (lstm || is_gru(ctx)) {
    *cell_data = reserve_data->Slice(gate_end, cell_end);
    hidden_begin = cell_end;
    if (lstm) {
      *cell_act_data = reserve_data->Slice(cell_end, cell_act_end);
      hidden_begin = cell_act_end;
    }
  }

  if (num_layers > 1) {
    *hidden_data =
        reserve_data->Slice(hidden_begin, hidden_begin + num_layers - 1);
  }
}

template <typename TensorType>
void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
                            const int& num_layers, const int& gate_num,
                            const bool& is_bidirec,
                            std::vector<TensorList>* params_vec) {
  // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers
  // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
  // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
  const int& direction_num = is_bidirec ? 2 : 1;
  const int& layer_weight_size = 4 * direction_num;
  const int& all_weight_size = num_layers * layer_weight_size;
  const int& bias_start_idx = all_weight_size / 2;
  for (int i = 0; i < num_layers; i++) {
    TensorList tensor_list;
    tensor_list.reserve(layer_weight_size);
    for (int j = 0; j < layer_weight_size; j++) {
      Tensor tensor_holder;
      tensor_list.emplace_back(tensor_holder);
    }
    for (int j = 0; j < layer_weight_size; j++) {
      int k = j % 4;
      const int& section = j / 4;
      int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
      if (k >= 2) {
        tensor_idx += bias_start_idx;
      }
      tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]);
    }
    params_vec->emplace_back(tensor_list);
  }
}

// Allocates `reserve_data` as a 2-D buffer and splits it via
// SplitReserveData.
// Shape:
//   [rows_per_layer * num_layers + (num_layers - 1),
//    direction_num * time_step * batch_size * hidden_size]
// rows_per_layer is gate_num, plus 2 for lstm or plus 1 for gru.
template <typename CellType, typename T>
void AllocateReserveData(const framework::ExecutionContext& ctx,
                         Tensor* reserve_data, Tensor* gate_data,
                         Tensor* cell_data, Tensor* cell_act_data,
                         Tensor* hidden_data, const Tensor* input,
                         bool is_bidirec, int num_layers, int gate_num,
                         int hidden_size) {
  const int direction_num = is_bidirec ? 2 : 1;
  const int time_step = input->dims()[0];
  const int batch_size = input->dims()[1];
  const int block_size = direction_num * time_step * batch_size * hidden_size;

  int rows_per_layer = gate_num;
  if (is_lstm(ctx)) {
    rows_per_layer += 2;  // cell rows + cell-activation rows
  } else if (is_gru(ctx)) {
    rows_per_layer += 1;  // cell rows
  }
  // The extra num_layers - 1 rows hold the hidden outputs of inner layers.
  const int total_rows = rows_per_layer * num_layers + (num_layers - 1);

  reserve_data->Resize({total_rows, block_size});
  reserve_data->mutable_data<T>(ctx.GetPlace());
  SplitReserveData(ctx, reserve_data, gate_data, cell_data, cell_act_data,
                   hidden_data, direction_num, time_step, batch_size,
                   hidden_size, gate_num, num_layers);
}

// Runs the forward pass of a possibly multi-layer, possibly bidirectional RNN
// on CPU.
//
// Layer i consumes the previous layer's output. During training
// (`!is_test`) with dropout_prob != 0, dropout is applied to that output
// first. Inner-layer results are written:
//   - training:  into slices of `reserve_data` (hidden rows);
//   - inference: into a ping-pong scratch tensor.
// The last layer's result ends up in `output`. last_h / last_c are passed
// (unbound per layer) to the layers, which write the final states into them.
// `input_size` and `cell_type` are accepted for interface parity and unused.
template <typename CellType, template <typename, typename> class LayerT,
          template <typename, typename> class SingleLayerT,
          template <typename, typename> class BidirLayerT, typename T>
void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input,
             const std::vector<const Tensor*>& weight_list,
             const Tensor* init_h, const Tensor* init_c,
             const Tensor* sequence_length, Tensor* last_h, Tensor* last_c,
             Tensor* output, Tensor* dropout_mask, const int& num_layers,
             const int& gate_num, const int& input_size,
             const int& hidden_size, const bool& is_bidirec,
             const std::string& cell_type, const float& dropout_prob,
             bool is_test, const int& seed, Tensor* reserve_data) {
  const int direction_num = is_bidirec ? 2 : 1;
  const auto& init_h_dims = init_h->dims();
  PADDLE_ENFORCE_EQ(init_h_dims[0], num_layers * direction_num,
                    platform::errors::InvalidArgument(
                        "The num_layers of in RNN layer must be the same as "
                        "first dim of init hidden, but received"
                        " num_layers:%d, dim:%d",
                        num_layers, init_h_dims[0]));
  if (is_lstm(ctx)) {
    const auto& init_c_dims = init_c->dims();
    // Report the cell-state dim that is actually being checked.
    PADDLE_ENFORCE_EQ(init_c_dims[0], num_layers * direction_num,
                      platform::errors::InvalidArgument(
                          "The num_layers of in RNN layer must be the same as "
                          "first dim of cell state hidden, but received"
                          " num_layers:%d, dim:%d",
                          num_layers, init_c_dims[0]));
  }
  CellType cell;

  std::vector<TensorList> parameter_lists;
  parameter_lists.reserve(num_layers);
  reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists);

  Tensor gate_data, cell_data, cell_act_data, hidden_data;

  if (!is_test) {
    AllocateReserveData<CellType, T>(
        ctx, reserve_data, &gate_data, &cell_data, &cell_act_data, &hidden_data,
        input, is_bidirec, num_layers, gate_num, hidden_size);
    gate_data.Resize({num_layers, gate_data.numel() / num_layers});
    cell_data.Resize({num_layers, cell_data.numel() / num_layers});
    cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});

    if (num_layers > 1) {
      hidden_data.Resize(
          {num_layers - 1, hidden_data.numel() / (num_layers - 1)});
    }
  }
  Tensor* input_holder = nullptr;
  Tensor* output_holder = output;
  Tensor temp;
  bool has_allocate_mem = false;

  auto init_h_unbind = Unbind(*init_h);
  auto last_h_unbind = Unbind(*last_h);
  TensorList init_c_unbind, last_c_unbind;
  if (is_lstm(ctx)) {
    init_c_unbind = Unbind(*init_c);
    last_c_unbind = Unbind(*last_c);
  }

  Tensor curr_gate_data, curr_cell_data, curr_cell_act_data;
  Tensor curr_hidden_data, prev_hidden_data;
  bool has_dropout_reset = false;
  for (int i = 0; i < num_layers; i++) {
    if (!is_test) {
      if (cell_data.numel() > 0) /** for lstm, gru **/ {
        curr_cell_data = cell_data.Slice(i, i + 1);
      }
      if (cell_act_data.numel() > 0) /*for lstm*/ {
        curr_cell_act_data = cell_act_data.Slice(i, i + 1);
      }
      curr_gate_data = gate_data.Slice(i, i + 1);
      output_holder = output;
      // Inner layers write into their reserved hidden rows so the backward
      // pass can read them; only the last layer writes `output` directly.
      if (i < num_layers - 1 && num_layers > 1) {
        curr_hidden_data = hidden_data.Slice(i, i + 1);
        curr_hidden_data.Resize(output->dims());
        output_holder = &curr_hidden_data;
      }
    }
    if (i > 0) {
      if (!has_allocate_mem) {
        temp.Resize(output->dims());
        temp.mutable_data<T>(ctx.GetPlace());
        input_holder = &temp;
        has_allocate_mem = true;
      }
      if (!is_test) {
        prev_hidden_data = hidden_data.Slice(i - 1, i);
        input_holder->Resize(output->dims());
        if (dropout_prob != 0) {
          dropout_cpu_function_inplace<T>(ctx, &prev_hidden_data, input_holder,
                                          dropout_mask, dropout_prob, seed,
                                          is_test, &has_dropout_reset);
        } else {
          input_holder = &prev_hidden_data;
          input_holder->Resize(output->dims());
        }
      } else {
        // Inference: ping-pong between `output` and `temp`.
        SwapPoniter(&output_holder, &input_holder);
      }
    }
    const Tensor* input_temp_holder = input;
    if (i > 0) {
      input_temp_holder = input_holder;
    }
    LayerT<T, CellType>* layer;
    SingleLayerT<T, CellType> slayer(cell);
    BidirLayerT<T, CellType> blayer(cell);
    if (is_bidirec) {
      layer = &blayer;
    } else {
      layer = &slayer;
    }
    (*layer)(ctx, input_temp_holder, parameter_lists[i], init_h_unbind,
             init_c_unbind, sequence_length, last_h_unbind, last_c_unbind,
             output_holder, i, gate_num, &curr_gate_data, &curr_cell_data,
             &curr_cell_act_data, is_test);
  }
  // If the last layer wrote into the scratch buffer (inference with an even
  // number of layers), copy its result into `output`. This is equivalent to
  // the former `num_layers % 2 == 0` check, which otherwise copied `output`
  // onto itself.
  if (output_holder != output) {
    framework::TensorCopy(
        *output_holder, ctx.GetPlace(),
        ctx.template device_context<platform::CPUDeviceContext>(), output);
  }
}

// CPU kernel for the `rnn` op.
// Reads inputs and attributes, prepares the output, state and dropout-mask
// tensors, then dispatches to RnnFunc with the cell type selected by `mode`
// (lstm, rnn_relu, rnn_tanh or gru).
template <typename DeviceContext, typename T>
class RNNCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("Input");
    auto pre_state = ctx.MultiInput<Tensor>("PreState");
    auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
    auto state = ctx.MultiOutput<Tensor>("State");
    auto* output = ctx.Output<Tensor>("Out");
    auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
    auto* reserve_data = ctx.Output<Tensor>("Reserve");
    const int& num_layers = ctx.Attr<int>("num_layers");
    const bool& is_bidirec = ctx.Attr<bool>("is_bidirec");
    const int& input_size = ctx.Attr<int>("input_size");
    const int& hidden_size = ctx.Attr<int>("hidden_size");
    const float& dropout_prob = ctx.Attr<float>("dropout_prob");
    const std::string& mode = ctx.Attr<std::string>("mode");
    const int& seed = ctx.Attr<int>("seed");
    bool is_test = ctx.HasAttr("is_test") ? ctx.Attr<bool>("is_test") : false;

    // SequenceLength is optional; nullptr means "all sequences are full".
    const Tensor* sequence_length = nullptr;
    if (ctx.HasInput("SequenceLength")) {
      sequence_length = ctx.Input<Tensor>("SequenceLength");
    }

    // Drop a stale mask whose element count differs from the output before
    // (re)allocating it with the output's shape.
    if (dropout_mask->IsInitialized() &&
        dropout_mask->numel() != output->numel()) {
      dropout_mask->clear();
    }
    dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());

    // init the output and allocate the memory
    output->mutable_data<T>(ctx.GetPlace());
    int gate_num = 4;
    state[0]->mutable_data<T>(ctx.GetPlace());
    if (is_lstm(ctx)) {
      // LSTM carries a second state tensor (cell state) in state[1].
      state[1]->mutable_data<T>(ctx.GetPlace());
      RnnFunc<LSTMCell<T>, Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], pre_state[1], sequence_length,
          state[0], state[1], output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_rnn_relu(ctx)) {
      gate_num = 1;
      RnnFunc<
          SimpleRNNCell<T, ReluCPUFunctor, math::detail::ActivationType::kReLU>,
          Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_rnn_tanh(ctx)) {
      gate_num = 1;
      RnnFunc<
          SimpleRNNCell<T, TanhFunctor, math::detail::ActivationType::kTanhV2>,
          Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    } else if (is_gru(ctx)) {
      gate_num = 3;
      RnnFunc<GRUCell<T>, Layer, SingleLayer, BidirLayer, T>(
          ctx, input, weight_list, pre_state[0], nullptr, sequence_length,
          state[0], nullptr, output, dropout_mask, num_layers, gate_num,
          input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test,
          seed, reserve_data);
    }
  }
};

// Resets the check_ig / check_fg / check_og pointers of `lstm_value` to null.
// NOTE(review): these are presumably the peephole weights, which this kernel
// does not use — confirm against lstm_compute.h.
template <typename T>
void create_lstm_value(math::LstmMetaValue<T>* lstm_value) {
  lstm_value->check_og = nullptr;
  lstm_value->check_fg = nullptr;
  lstm_value->check_ig = nullptr;
}

// Resets the check_ig_grad / check_fg_grad / check_og_grad pointers of
// `lstm_grad` to null (the gradient-side counterpart of create_lstm_value).
template <typename T>
void create_lstm_grad(math::LstmMetaGrad<T>* lstm_grad) {
  lstm_grad->check_og_grad = nullptr;
  lstm_grad->check_fg_grad = nullptr;
  lstm_grad->check_ig_grad = nullptr;
}

// Fills `dst` with the contents of `v`.
// `dst` is resized to a 1-D tensor of v.size() elements, allocated on the
// context's place, and then receives the values element by element.
template <typename T>
void create_tensor_by_list(const framework::ExecutionContext& context,
                           Tensor* dst, const std::vector<T>& v) {
  const int element_count = static_cast<int>(v.size());
  dst->Resize({element_count});
  T* out = dst->mutable_data<T>(context.GetPlace());
  std::copy(v.begin(), v.end(), out);
}

// Base class for the CPU RNN backward layers. Holds a gradient cell (`cell_`)
// and implements the per-direction time loop plus the pre/post-processing
// shared by the single- and bi-directional gradient layers.
template <typename T, typename GradCellType>
struct GradLayer {
  explicit GradLayer(const GradCellType& cell) : cell_(cell) {}
  virtual ~GradLayer() {}

  // Back-propagates one direction (is_reverse selects it) of layer
  // `layer_idx` through all `time_step` steps, from the last step to the
  // first.
  //
  // Per step:
  //   - the output gradient is folded into the running hidden-state gradient
  //     (masked by sequence length when `has_sequence_length`);
  //   - `cell_` produces the gate gradient plus the previous-step h/c
  //     gradients.
  //
  // Afterwards `postprocess` turns the accumulated gate gradients into
  // gradients for W_hi, the input and the biases. Gradients for init_h /
  // init_c are written into init_h_grad_unbind / init_c_grad_unbind when
  // those lists are non-empty.
  void run_rnn_grad_function(
      const framework::ExecutionContext& context,
      const platform::CPUDeviceContext& device_ctx, const Tensor* input,
      Tensor* input_grad, const Tensor* sequence_length,
      std::vector<Tensor>* init_h_unbind, std::vector<Tensor>* init_c_unbind,
      std::vector<Tensor>* init_h_grad_unbind,
      std::vector<Tensor>* init_c_grad_unbind, Tensor* layer_grad_gate_tensor,
      std::vector<Tensor>* layer_gate_tensor_unbind,
      std::vector<Tensor>* layer_grad_gate_tensor_unbind,
      std::vector<Tensor>* layer_state_tensor_unbind,
      std::vector<Tensor>* layer_act_state_tensor_unbind,
      std::vector<Tensor>* output_tensor_unbind,
      std::vector<Tensor>* output_grad_tensor_unbind,
      const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const std::vector<TensorList>& parameter_lists,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& time_step, const bool& has_sequence_length,
      const bool& is_bidirec, const bool& is_reverse) {
    const int& direction_num = is_bidirec ? 2 : 1;
    const int& current_reverse_idx = is_reverse ? 1 : 0;
    const int& current_layer_idx =
        direction_num * layer_idx + current_reverse_idx;
    // The reverse direction's states are stored after the forward ones.
    int begin_idx = 0;
    if (is_reverse) {
      begin_idx = time_step;
    }

    Tensor mask_matrix;
    TensorList mask_tensor_list;
    int mask_min_length = time_step;
    if (has_sequence_length) {
      mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]}));
      create_mask_matrix<T>(context, sequence_length, &mask_matrix, is_reverse,
                            &mask_min_length);
      mask_tensor_list = Unbind(mask_matrix);
    }
    // copy the last_h, last_c for swapping pointer
    Tensor a, b;
    Tensor* dynamic_grad_last_h = &a;
    Tensor* dynamic_grad_last_c = &b;
    dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims());
    dynamic_grad_last_h->mutable_data<T>(context.GetPlace());
    framework::TensorCopy(last_h_grad_unbind[current_layer_idx],
                          context.GetPlace(), dynamic_grad_last_h);
    if (last_c_grad_unbind.size() > 0) {
      dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims());
      dynamic_grad_last_c->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(last_c_grad_unbind[current_layer_idx],
                            context.GetPlace(), dynamic_grad_last_c);
    } else {
      dynamic_grad_last_c = nullptr;
    }

    // The "pre" gradient buffers alias init_h_grad / init_c_grad when those
    // are requested; otherwise they are local scratch tensors.
    Tensor c, d;
    Tensor* dynamic_grad_pre_h = &c;
    Tensor* dynamic_grad_pre_c = &d;
    math::SetConstant<platform::CPUDeviceContext, T> zero;
    if (init_h_grad_unbind->size() > 0) {
      dynamic_grad_pre_h->ShareDataWith(
          (*init_h_grad_unbind)[current_layer_idx]);
    } else {
      dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims());
      dynamic_grad_pre_h->mutable_data<T>(context.GetPlace());
      zero(device_ctx, dynamic_grad_pre_h, static_cast<T>(0.0));
    }
    if (init_c_grad_unbind->size() > 0) {
      dynamic_grad_pre_c->ShareDataWith(
          (*init_c_grad_unbind)[current_layer_idx]);
    } else {
      if (is_lstm(context) || is_gru(context)) {
        dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims());
        dynamic_grad_pre_c->mutable_data<T>(context.GetPlace());
        if (is_gru(context)) {
          dynamic_grad_last_c = dynamic_grad_pre_c;
        }
      } else {
        dynamic_grad_pre_c = nullptr;
      }
    }

    if (is_reverse) {
      // must be reverse the input, output, input_grad, output_grad
      // the gate and grad_gate must be reverse
      // mask_tensor_list is intentionally not reversed here;
      // create_mask_matrix already received `is_reverse` above.
      std::reverse(layer_gate_tensor_unbind->begin(),
                   layer_gate_tensor_unbind->end());
      std::reverse(layer_grad_gate_tensor_unbind->begin(),
                   layer_grad_gate_tensor_unbind->end());
      std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end());
      std::reverse(output_grad_tensor_unbind->begin(),
                   output_grad_tensor_unbind->end());
    }

    // Slot 1 of each direction is W_hh; its gradient is accumulated by the
    // cell over all steps, so start it from zero.
    Tensor* weight_grad =
        &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]);
    weight_grad->mutable_data<T>(context.GetPlace());
    zero(device_ctx, weight_grad, static_cast<T>(0.0));

    Tensor* pre_hidden = nullptr;
    Tensor* pre_state = nullptr;
    Tensor* hidden = nullptr;
    if (is_gru(context)) {
      zero(device_ctx,
           &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
           static_cast<T>(0.0));
    }
    // Without sequence_length, mask_tensor_list is empty; indexing it would be
    // undefined behavior, so pass this placeholder to the cell instead.
    Tensor empty_mask;
    for (int i = time_step - 1; i >= 0; --i) {
      if (has_sequence_length) {
        this->mask_preprocess(context, &(*output_grad_tensor_unbind)[i],
                              dynamic_grad_last_h, dynamic_grad_last_c,
                              dynamic_grad_pre_h, dynamic_grad_pre_c,
                              mask_tensor_list[i]);
      } else {
        this->preprocess(context, &(*output_grad_tensor_unbind)[i],
                         dynamic_grad_last_h);
      }
      hidden = &(*output_tensor_unbind)[i];
      if (i == 0) {
        pre_hidden = &(*init_h_unbind)[current_layer_idx];
        if (init_c_unbind->size() > 0) {
          pre_state = &(*init_c_unbind)[current_layer_idx];
        }
        // NOTE(review): when init_c_unbind is empty, pre_state keeps the value
        // from the previous iteration — confirm the cell ignores it then.
      } else {
        pre_hidden = &(*output_tensor_unbind)[i - 1];
        if (layer_state_tensor_unbind->size() > 0) {
          pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1];
        }
      }
      Tensor* step_mask =
          has_sequence_length ? &mask_tensor_list[i] : &empty_mask;
      this->cell_(
          context, &(*layer_gate_tensor_unbind)[i],
          &(*layer_state_tensor_unbind)[begin_idx + i],
          &(*layer_act_state_tensor_unbind)[begin_idx + i], hidden,
          &(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]),
          pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c,
          &(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h,
          dynamic_grad_pre_c,
          &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
          *step_mask, has_sequence_length);
      SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h);
      SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c);
    }
    // postprocess for gradient for w_hi, X, bias_hi, bias_hh
    this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad,
                      parameter_lists[layer_idx],
                      &((*weight_list_grad)[layer_idx]), is_reverse);

    // copy the gradient to init_c init_h
    // After an even number of swaps the result sits in the local buffers
    // (a / b); after an odd number it already aliases init_h_grad / init_c_grad.
    if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
      framework::TensorCopy(*dynamic_grad_last_h, context.GetPlace(),
                            &((*init_h_grad_unbind)[current_layer_idx]));
    }
    if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) {
      framework::TensorCopy(*dynamic_grad_last_c, context.GetPlace(),
                            &((*init_c_grad_unbind)[current_layer_idx]));
    }
  }

  virtual void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, const TensorList& init_h_unbind,
      const TensorList& init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      const std::vector<TensorList>& weight_list_grad, const int& layer_idx,
      const int& gate_num) {}

  // Adds this step's output gradient into the running hidden gradient
  // (no sequence-length masking): grad_last_h += grad_output.
  void preprocess(const framework::ExecutionContext& context,
                  const Tensor* grad_output, Tensor* grad_last_h) {
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto output_grad = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
    auto last_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
    // the output gradient contribute the gradient to last_h
    last_h_grad.device(place) = last_h_grad + output_grad;
  }

  // Masked variant of preprocess, with m = mask broadcast over the hidden
  // axis:
  //   last_h += output * m
  //   pre_h   = (1 - m) * last_h
  //   last_h  = m * last_h
  // For LSTM, the cell gradient is split the same way:
  //   pre_c   = (1 - m) * last_c
  //   last_c  = m * last_c
  void mask_preprocess(const framework::ExecutionContext& context,
                       const Tensor* grad_output, Tensor* grad_last_h,
                       Tensor* grad_last_c, Tensor* grad_pre_h,
                       Tensor* grad_pre_c, const Tensor& mask_tensor) {
    auto& place = *context.template device_context<platform::CPUDeviceContext>()
                       .eigen_device();
    auto mask = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
    auto mask_broadcast =
        mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));

    auto last_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
    auto pre_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_h, grad_pre_h->dims().size() - 1);
    auto output_grad = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
    last_h_grad.device(place) = last_h_grad + output_grad * mask_broadcast;
    pre_h_grad.device(place) = (1 - mask_broadcast) * last_h_grad;
    last_h_grad.device(place) = mask_broadcast * last_h_grad;

    if (grad_last_c && grad_pre_c && is_lstm(context)) {
      auto last_c_grad = framework::EigenMatrix<T>::Reshape(
          *grad_last_c, grad_last_c->dims().size() - 1);
      auto pre_c_grad = framework::EigenMatrix<T>::Reshape(
          *grad_pre_c, grad_pre_c->dims().size() - 1);
      pre_c_grad.device(place) = (1 - mask_broadcast) * last_c_grad;
      last_c_grad.device(place) = mask_broadcast * last_c_grad;
    }
  }

  // Turns the accumulated gate gradients of one direction into:
  //   - grad of W_hi       (slot begin_idx + 0), overwritten;
  //   - grad of the input  (input_grad), accumulated;
  //   - grad of b_hi       (slot begin_idx + 2);
  //   - grad of b_hh       (slot begin_idx + 3), except for GRU, whose cell
  //     handles that slot during the time loop.
  // begin_idx is 4 for the reverse direction, 0 otherwise.
  void postprocess(const framework::ExecutionContext& context,
                   const Tensor& grad_gate, const Tensor& input,
                   Tensor* input_grad, const TensorList& parameters,
                   TensorList* grad_parameters, const int& is_reverse) {
    // we get the grad_gate step by step, and need to broadcast the grad to the
    // grad_w_hi, grad_bias_hi, grad_bias_hh
    int begin_idx = 0;
    if (is_reverse) {
      begin_idx = 4;
    }
    auto& device_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(device_ctx);

    // calc the gradient for the w_hi
    auto mat_dim_out_grad =
        math::CreateMatrixDescriptor(grad_gate.dims(), 0, true);
    auto mat_dim_input = math::CreateMatrixDescriptor(input.dims(), 0, false);
    mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_;
    mat_dim_out_grad.batch_size_ = 0;
    mat_dim_input.height_ *= mat_dim_input.batch_size_;
    mat_dim_input.batch_size_ = 0;
    blas.MatMul(grad_gate, mat_dim_out_grad, input, mat_dim_input,
                static_cast<T>(1.0), &((*grad_parameters)[begin_idx + 0]),
                T(0));

    // calc the gradient for the X
    auto mat_dim_out_grad_new =
        math::CreateMatrixDescriptor(grad_gate.dims(), 0, false);
    mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_;
    mat_dim_out_grad_new.batch_size_ = 0;
    // Describe the same tensor that is multiplied below (this direction's
    // W_hi), not always the forward one.
    auto mat_dim_parameter = math::CreateMatrixDescriptor(
        parameters[begin_idx + 0].dims(), 0, false);
    blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0],
                mat_dim_parameter, static_cast<T>(1.0), input_grad, T(1));

    // calc the gradient of Bias_hi, Bias_hh
    math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
    Tensor tmp_grad_gate;
    tmp_grad_gate.ShareDataWith(grad_gate);
    tmp_grad_gate.Resize(
        {grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]});
    col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2]));
    // Bias_hh
    if (!is_gru(context)) {
      col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3]));
    }
  }
  GradCellType cell_;
};

template <typename T, typename GradCellType>
struct SingleGradLayer : GradLayer<T, GradCellType> {
  explicit SingleGradLayer(const GradCellType& cell)
      : GradLayer<T, GradCellType>(cell) {}
  virtual ~SingleGradLayer() {}

  // Back-propagates one unidirectional layer (`layer_idx`).
  //
  // `input_grad` is zero-filled first. The per-layer gate / state /
  // activated-state tensors saved by the forward pass are then reshaped to
  // [time_step * direction_num, batch_size, ...] and split along dim 0.
  // All of it is handed to GradLayer::run_rnn_grad_function together with a
  // freshly allocated gate-gradient buffer of the same shape as the gates.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, std::vector<Tensor>* init_h_unbind,
      std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& gate_num) {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(dev_ctx, input_grad, static_cast<T>(0.0));

    const bool is_bidirec = context.Attr<bool>("is_bidirec");
    const int steps = input->dims()[0];
    const int batch = input->dims()[1];
    const int directions = is_bidirec ? 2 : 1;
    const int hidden = context.Attr<int>("hidden_size");

    // Split the layer output and its gradient along dim 0 (one entry per
    // time step).
    auto out_steps = Unbind(*output);
    auto out_grad_steps = Unbind(*output_grad);

    // Gate values saved by the forward pass for this layer.
    auto gate_value = gate_tensor_unbind[layer_idx];
    gate_value.Resize({steps * directions, batch, hidden * gate_num});
    auto gate_value_steps = Unbind(gate_value);

    // Gate-gradient scratch buffer; same layout as the gate values.
    Tensor gate_grad;
    gate_grad.Resize(gate_value.dims());
    gate_grad.mutable_data<T>(context.GetPlace());
    auto gate_grad_steps = Unbind(gate_grad);

    // Optional cell state (empty list when the forward pass saved none).
    Tensor cell_state;
    TensorList cell_state_steps;
    if (!state_tensor_unbind.empty()) {
      cell_state = state_tensor_unbind[layer_idx];
      cell_state.Resize({steps * directions, batch, hidden});
      cell_state_steps = Unbind(cell_state);
    }

    // Optional activated cell state (empty list when none was saved).
    Tensor act_cell_state;
    TensorList act_cell_state_steps;
    if (!act_state_tensor_unbind.empty()) {
      act_cell_state = act_state_tensor_unbind[layer_idx];
      act_cell_state.Resize({steps * directions, batch, hidden});
      act_cell_state_steps = Unbind(act_cell_state);
    }

    const bool has_sequence_length = (sequence_length != nullptr);
    this->run_rnn_grad_function(
        context, dev_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &gate_grad,
        &gate_value_steps, &gate_grad_steps, &cell_state_steps,
        &act_cell_state_steps, &out_steps, &out_grad_steps,
        last_h_grad_unbind, last_c_grad_unbind, parameter_lists,
        weight_list_grad, layer_idx, steps, has_sequence_length, is_bidirec,
        false);
  }
};
// Splits `output` into two equally sized halves along `axis`.
//
// The halves are written to (*output_vec)[0] and (*output_vec)[1], which are
// resized and allocated here. The callers in this file pass axis 2 to
// separate the forward and backward halves of a bidirectional layer's
// concatenated output/gradient.
//
// Previously the shape was always derived from dims()[2], which silently
// ignored `axis` and truncated odd sizes. The halved dimension is now the one
// named by `axis`, and it must be even.
template <typename T>
void split_tensor_at_last_dim(const framework::ExecutionContext& context,
                              const platform::CPUDeviceContext& dev_ctx,
                              const Tensor* output,
                              std::vector<Tensor*>* output_vec,
                              const int& axis) {
  const auto& in_dims = output->dims();
  PADDLE_ENFORCE_EQ(
      axis >= 0 && axis < in_dims.size(), true,
      platform::errors::InvalidArgument(
          "The split axis (%d) must be in the range [0, %d).", axis,
          in_dims.size()));
  PADDLE_ENFORCE_EQ(
      in_dims[axis] % 2, 0,
      platform::errors::InvalidArgument(
          "Dimension %d of the tensor to split in half must be even, but "
          "received %d.",
          axis, in_dims[axis]));

  // Shape of each half: identical to the input except along `axis`.
  framework::DDim half_dims = in_dims;
  half_dims[axis] = in_dims[axis] / 2;

  std::vector<const framework::Tensor*> shape_refer;
  shape_refer.reserve(2);
  for (int k = 0; k < 2; ++k) {
    Tensor* part = (*output_vec)[k];
    part->Resize(half_dims);
    part->mutable_data<T>(context.GetPlace());
    shape_refer.emplace_back(part);
  }
  math::SplitFunctor<platform::CPUDeviceContext, T> functor;
  functor(dev_ctx, *output, shape_refer, axis, output_vec);
}

template <typename T, typename GradCellType>
struct BidirGradLayer : GradLayer<T, GradCellType> {
  explicit BidirGradLayer(const GradCellType& cell)
      : GradLayer<T, GradCellType>(cell) {}
  virtual ~BidirGradLayer() {}

  // Back-propagates one bidirectional layer (`layer_idx`).
  //
  // The layer output and its gradient hold both directions concatenated on
  // axis 2; they are split into forward/backward halves. The saved gate
  // tensor is viewed as [2 * time_step, batch_size, hidden_size * gate_num]:
  // rows [0, time_step) belong to the forward direction and
  // [time_step, 2 * time_step) to the backward direction. A zero-filled
  // gate-gradient buffer uses the same layout.
  // GradLayer::run_rnn_grad_function is then called once per direction; the
  // final argument is false for forward and true for backward.
  void operator()(
      const framework::ExecutionContext& context, const Tensor* input,
      const Tensor* output, std::vector<Tensor>* init_h_unbind,
      std::vector<Tensor>* init_c_unbind, const TensorList& last_h_grad_unbind,
      const TensorList& last_c_grad_unbind,
      const TensorList& gate_tensor_unbind,
      const TensorList& state_tensor_unbind,
      const TensorList& act_state_tensor_unbind, const Tensor* output_grad,
      const std::vector<TensorList>& parameter_lists,
      const Tensor* sequence_length, Tensor* input_grad,
      TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
      std::vector<TensorList>* weight_list_grad, const int& layer_idx,
      const int& gate_num) {
    const bool is_bidirec = context.Attr<bool>("is_bidirec");
    const int steps = input->dims()[0];
    const int batch = input->dims()[1];
    const int directions = is_bidirec ? 2 : 1;
    const int hidden = context.Attr<int>("hidden_size");

    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(dev_ctx, input_grad, static_cast<T>(0.0));

    // Forward / backward halves of the concatenated layer output.
    Tensor fw_out;
    Tensor bw_out;
    std::vector<Tensor*> out_halves{&fw_out, &bw_out};
    split_tensor_at_last_dim<T>(context, dev_ctx, output, &out_halves, 2);
    std::vector<Tensor> fw_out_steps = Unbind(*out_halves[0]);
    std::vector<Tensor> bw_out_steps = Unbind(*out_halves[1]);

    // Forward / backward halves of the incoming output gradient.
    Tensor fw_out_grad;
    Tensor bw_out_grad;
    std::vector<Tensor*> out_grad_halves{&fw_out_grad, &bw_out_grad};
    split_tensor_at_last_dim<T>(context, dev_ctx, output_grad,
                                &out_grad_halves, 2);
    auto fw_out_grad_steps = Unbind(*out_grad_halves[0]);
    auto bw_out_grad_steps = Unbind(*out_grad_halves[1]);

    // Saved gate values, sliced per direction.
    auto gate_value = gate_tensor_unbind[layer_idx];
    gate_value.Resize({steps * 2, batch, hidden * gate_num});
    auto fw_gate_value = gate_value.Slice(0, steps);
    auto bw_gate_value = gate_value.Slice(steps, 2 * steps);
    auto fw_gate_value_steps = Unbind(fw_gate_value);
    auto bw_gate_value_steps = Unbind(bw_gate_value);

    // Zero-initialised gate gradients, sliced the same way.
    Tensor gate_grad;
    gate_grad.Resize(gate_value.dims());
    gate_grad.mutable_data<T>(context.GetPlace());
    set_zero(dev_ctx, &gate_grad, static_cast<T>(0.0));
    auto fw_gate_grad = gate_grad.Slice(0, steps);
    auto bw_gate_grad = gate_grad.Slice(steps, 2 * steps);
    auto fw_gate_grad_steps = Unbind(fw_gate_grad);
    auto bw_gate_grad_steps = Unbind(bw_gate_grad);

    // Optional cell state (shared by both directions).
    Tensor cell_state;
    TensorList cell_state_steps;
    if (!state_tensor_unbind.empty()) {
      cell_state = state_tensor_unbind[layer_idx];
      cell_state.Resize({steps * directions, batch, hidden});
      cell_state_steps = Unbind(cell_state);
    }

    // Optional activated cell state (shared by both directions).
    Tensor act_cell_state;
    TensorList act_cell_state_steps;
    if (!act_state_tensor_unbind.empty()) {
      act_cell_state = act_state_tensor_unbind[layer_idx];
      act_cell_state.Resize({steps * directions, batch, hidden});
      act_cell_state_steps = Unbind(act_cell_state);
    }

    const bool has_sequence_length = (sequence_length != nullptr);

    // Forward direction.
    this->run_rnn_grad_function(
        context, dev_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &fw_gate_grad,
        &fw_gate_value_steps, &fw_gate_grad_steps, &cell_state_steps,
        &act_cell_state_steps, &fw_out_steps, &fw_out_grad_steps,
        last_h_grad_unbind, last_c_grad_unbind, parameter_lists,
        weight_list_grad, layer_idx, steps, has_sequence_length, is_bidirec,
        false);

    // Backward (reversed) direction.
    this->run_rnn_grad_function(
        context, dev_ctx, input, input_grad, sequence_length, init_h_unbind,
        init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &bw_gate_grad,
        &bw_gate_value_steps, &bw_gate_grad_steps, &cell_state_steps,
        &act_cell_state_steps, &bw_out_steps, &bw_out_grad_steps,
        last_h_grad_unbind, last_c_grad_unbind, parameter_lists,
        weight_list_grad, layer_idx, steps, has_sequence_length, is_bidirec,
        true);
  }
};

// Makes `dst` an independent deep copy of `src`: same dims, freshly
// allocated on the kernel's place, data copied with framework::TensorCopy.
//
// In this file it snapshots grad_pre_hidden / grad_pre_state before a cell
// step when sequence lengths are used. This lets
// GradCell::postprocess_pre_hidden_grad blend the old and new values through
// the mask.
//
// `src` is only read, so it is taken as a pointer-to-const. Existing callers
// passing `Tensor*` convert implicitly.
template <typename T>
void backup_tensor(const framework::ExecutionContext& context, Tensor* dst,
                   const Tensor* src) {
  auto& device_ctx =
      context.template device_context<platform::CPUDeviceContext>();
  dst->Resize(src->dims());
  dst->mutable_data<T>(context.GetPlace());
  framework::TensorCopy(*src, device_ctx.GetPlace(), device_ctx, dst);
}

template <typename T>
struct GradCell {
  virtual ~GradCell() {}

  // Computes the gradients of a single time step of a recurrent cell.
  // The base implementation does nothing; SimpleRNNGradCell, GRUGradCell and
  // LSTMGradCell override it.
  virtual void operator()(const framework::ExecutionContext& context,
                          Tensor* gate_tensor, Tensor* state_tensor,
                          Tensor* act_state_tensor, Tensor* hidden_tensor,
                          const Tensor* weight_hh, Tensor* pre_hidden,
                          Tensor* pre_state, Tensor* grad_hidden,
                          Tensor* grad_state, Tensor* grad_gate,
                          Tensor* grad_weight_hh, Tensor* grad_pre_hidden,
                          Tensor* grad_pre_state, Tensor* grad_bias_hh,
                          const Tensor& mask_tensor,
                          bool has_sequence_length) const {}

  // Applies the sequence-length mask to the previous-step gradients. It is a
  // no-op when has_sequence_length is false.
  //
  //   grad = (1 - mask) * grad_bak + grad * mask
  //
  // Where mask is 1 the freshly computed gradient is kept; where it is 0 the
  // backed-up value is restored. `mask_tensor` is read as dims()[1] rows x 1
  // column and broadcast across dims()[2] of grad_pre_hidden. The state
  // gradient is processed the same way when grad_pre_state is non-null.
  void postprocess_pre_hidden_grad(const framework::ExecutionContext& context,
                                   Tensor* grad_pre_hidden,
                                   Tensor* grad_pre_hidden_bak,
                                   Tensor* grad_pre_state,
                                   Tensor* grad_pre_state_bak,
                                   const Tensor& mask_tensor,
                                   bool has_sequence_length) const {
    if (!has_sequence_length) {
      return;
    }
    auto& eigen_place =
        *context.template device_context<platform::CPUDeviceContext>()
             .eigen_device();
    auto mask_col = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
    auto mask_full = mask_col.broadcast(
        Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));

    auto hidden_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
    auto hidden_grad_bak = framework::EigenMatrix<T>::Reshape(
        *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
    hidden_grad.device(eigen_place) =
        (1 - mask_full) * hidden_grad_bak + hidden_grad * mask_full;

    if (grad_pre_state == nullptr) {
      return;
    }
    auto state_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_state, grad_pre_state->dims().size() - 1);
    auto state_grad_bak = framework::EigenMatrix<T>::Reshape(
        *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
    state_grad.device(eigen_place) =
        (1 - mask_full) * state_grad_bak + state_grad * mask_full;
  }

  // Overwrites grad_pre_hidden with grad_gate x weight_hh (alpha = 1,
  // beta = 0), then applies the sequence-length mask via
  // postprocess_pre_hidden_grad. grad_gate's batch descriptor dimension is
  // folded into its row count so a single 2-D GEMM is issued.
  virtual void update_pre_hidden_grad(
      const framework::ExecutionContext& context, Tensor* grad_gate,
      const Tensor* weight_hh, Tensor* grad_pre_hidden,
      Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state,
      Tensor* grad_pre_state_bak, const Tensor& mask_tensor,
      bool has_sequence_length) const {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);

    auto gate_desc =
        math::CreateMatrixDescriptor(grad_gate->dims(), 0, false);
    gate_desc.height_ *= gate_desc.batch_size_;
    gate_desc.batch_size_ = 0;
    auto weight_desc =
        math::CreateMatrixDescriptor(weight_hh->dims(), 0, false);

    blas.MatMul(*grad_gate, gate_desc, *weight_hh, weight_desc,
                static_cast<T>(1.0), grad_pre_hidden, static_cast<T>(0));
    postprocess_pre_hidden_grad(context, grad_pre_hidden, grad_pre_hidden_bak,
                                grad_pre_state, grad_pre_state_bak, mask_tensor,
                                has_sequence_length);
  }

  // Accumulates grad_weight_hh += grad_gate^T x pre_hidden (alpha = 1,
  // beta = 1). The batch descriptor dimension of both operands is folded into
  // the row count before the GEMM.
  virtual void update_weight_hh_grad(const framework::ExecutionContext& context,
                                     Tensor* grad_gate, Tensor* pre_hidden,
                                     Tensor* grad_weight_hh) const {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);

    auto gate_t_desc =
        math::CreateMatrixDescriptor(grad_gate->dims(), 0, true);
    gate_t_desc.height_ *= gate_t_desc.batch_size_;
    gate_t_desc.batch_size_ = 0;
    auto hidden_desc =
        math::CreateMatrixDescriptor(pre_hidden->dims(), 0, false);
    hidden_desc.height_ *= hidden_desc.batch_size_;
    hidden_desc.batch_size_ = 0;

    blas.MatMul(*grad_gate, gate_t_desc, *pre_hidden, hidden_desc,
                static_cast<T>(1.0), grad_weight_hh, static_cast<T>(1.0));
  }
};

template <typename T, template <typename> class EigenActivationBackwardFunctor>
struct SimpleRNNGradCell : GradCell<T> {
  // One backward step of a simple RNN cell h = act(z).
  //
  // The incoming hidden gradient dh is turned into dz by
  // EigenActivationBackwardFunctor, invoked as functor(place, z, h, dh, dz).
  // dz is then propagated to grad_pre_hidden (masked when sequence lengths
  // are used) and accumulated into grad_weight_hh.
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_bias_hh, const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();

    // Snapshot the current previous-hidden gradient so masked (padded)
    // positions can be restored after it is recomputed.
    Tensor grad_pre_hidden_backup;
    if (has_sequence_length) {
      backup_tensor<T>(context, &grad_pre_hidden_backup, grad_pre_hidden);
    }

    auto dz_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad"));
    auto dh_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad"));
    auto h_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value"));
    // z is not otherwise needed here, but the functor signature takes it.
    auto z_vec = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value"));

    EigenActivationBackwardFunctor<T> act_backward;
    act_backward(*dev_ctx.eigen_device(), z_vec, h_vec, dh_vec, dz_vec);

    this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden,
                                 &grad_pre_hidden_backup, nullptr, nullptr,
                                 mask_tensor, has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
  }
};

template <typename T>
struct GRUGradCell : GradCell<T> {
  // One backward step of a GRU cell, delegated to GRUUnitGradFunctorV2 with
  // "sigmoid_v2" as the gate activation and "tanh_v2" as the node activation.
  //
  // weight_hh and grad_weight_hh are addressed as two consecutive regions:
  // the first 2 * frame_size * frame_size elements are the gate weights, and
  // the remainder holds the state weights. grad_bias_hh is handed to the
  // functor as well. When sequence lengths are used, grad_pre_hidden is
  // finally blended with its pre-step backup through the mask.
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_bias_hh, const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const size_t frame_size = pre_hidden->dims()[2];
    const size_t batch_size = pre_hidden->dims()[1];

    Tensor grad_pre_hidden_backup;
    if (has_sequence_length) {
      backup_tensor<T>(context, &grad_pre_hidden_backup, grad_pre_hidden);
    }

    // Clear grad_pre_hidden before the GRU functor writes into it.
    math::SetConstant<platform::CPUDeviceContext, T> set_zero;
    set_zero(dev_ctx, grad_pre_hidden, static_cast<T>(0.0));

    // Offset of the state-weight region inside weight_hh / grad_weight_hh.
    const size_t state_weight_offset = 2 * frame_size * frame_size;

    math::GRUMetaValue<T> value;
    value.gate_value = gate_tensor->data<T>();
    value.prev_out_value = pre_hidden->data<T>();
    value.reset_output_value = state_tensor->data<T>();
    value.state_weight = weight_hh->data<T>() + state_weight_offset;
    value.gate_weight = weight_hh->data<T>();

    math::GRUMetaGrad<T> grad;
    grad.gate_grad = grad_gate->data<T>();
    grad.reset_output_grad = grad_state->data<T>();
    grad.prev_out_grad = grad_pre_hidden->data<T>();
    grad.output_grad = grad_hidden->data<T>();
    grad.gate_weight_grad = grad_weight_hh->data<T>();
    grad.state_weight_grad = grad_weight_hh->data<T>() + state_weight_offset;
    grad.bias_hh_grad = grad_bias_hh->data<T>();

    auto gate_activation = math::detail::GetActivationType("sigmoid_v2");
    auto node_activation = math::detail::GetActivationType("tanh_v2");
    math::GRUUnitGradFunctorV2<platform::CPUDeviceContext, T>::compute(
        dev_ctx, value, grad, frame_size, batch_size, node_activation,
        gate_activation);

    this->postprocess_pre_hidden_grad(context, grad_pre_hidden,
                                      &grad_pre_hidden_backup, nullptr,
                                      nullptr, mask_tensor,
                                      has_sequence_length);
  }
};

template <typename T>
struct LSTMGradCell : GradCell<T> {
  // One backward step of an LSTM cell.
  //
  // The element-wise gate/state gradients come from LstmUnitGradFunctor, with
  // "sigmoid_v2" gate activation and "tanh_v2" state/candidate activations.
  // cell_clip is passed as 0; presumably that disables clipping, so confirm
  // against lstm_compute. The resulting gate gradient is then propagated to
  // grad_pre_hidden (and grad_pre_state under a sequence mask) and
  // accumulated into grad_weight_hh.
  void operator()(const framework::ExecutionContext& context,
                  Tensor* gate_tensor, Tensor* state_tensor,
                  Tensor* act_state_tensor, Tensor* hidden_tensor,
                  const Tensor* weight_hh, Tensor* pre_hidden,
                  Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                  Tensor* grad_gate, Tensor* grad_weight_hh,
                  Tensor* grad_pre_hidden, Tensor* grad_pre_state,
                  Tensor* grad_bias_hh, const Tensor& mask_tensor,
                  bool has_sequence_length) const override {
    auto& dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
    const size_t frame_size = state_tensor->dims()[2];
    const size_t batch_size = state_tensor->dims()[1];

    // Snapshots used to restore masked (padded) positions afterwards.
    Tensor grad_pre_hidden_backup;
    Tensor grad_pre_state_backup;
    if (has_sequence_length) {
      backup_tensor<T>(context, &grad_pre_hidden_backup, grad_pre_hidden);
      backup_tensor<T>(context, &grad_pre_state_backup, grad_pre_state);
    }

    math::LstmMetaValue<T> value;
    create_lstm_value(&value);
    value.gate_value = gate_tensor->data<T>();
    value.state_value = state_tensor->data<T>();
    value.state_active_value = act_state_tensor->data<T>();
    value.prev_state_value = pre_state->data<T>();
    value.output_value = nullptr;

    math::LstmMetaGrad<T> grad;
    create_lstm_grad(&grad);
    grad.state_grad = grad_state->data<T>();
    grad.gate_grad = grad_gate->data<T>();
    grad.output_grad = grad_hidden->data<T>();
    grad.prev_state_grad = grad_pre_state->data<T>();
    grad.state_active_grad = nullptr;

    auto gate_activation = math::detail::GetActivationType("sigmoid_v2");
    auto state_activation = math::detail::GetActivationType("tanh_v2");
    auto candidate_activation = math::detail::GetActivationType("tanh_v2");

    const T cell_clip = static_cast<T>(0.0);
    math::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
        dev_ctx, value, grad, frame_size, batch_size, cell_clip,
        gate_activation, state_activation, candidate_activation, false);

    this->update_pre_hidden_grad(context, grad_gate, weight_hh,
                                 grad_pre_hidden, &grad_pre_hidden_backup,
                                 grad_pre_state, &grad_pre_state_backup,
                                 mask_tensor, has_sequence_length);
    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
  }
};

// Backward pass of the CPU rnn op for one cell type.
//
// Reads the forward inputs/outputs, the "Reserve" buffer saved by the forward
// kernel and the incoming gradients of "Out" / "State". It then walks the
// layers from the last one down to the first. Each layer is driven by
// SingleGradLayerT or BidirGradLayerT, selected by the "is_bidirec" attribute.
// The input gradient computed for layer i becomes the output gradient of
// layer i - 1; for every layer except the first, the dropout gradient is
// applied to it in between.
//
// Template parameters:
//   GradCellType     - per-time-step gradient cell.
//   SingleGradLayerT - layer driver used when is_bidirec is false.
//   BidirGradLayerT  - layer driver used when is_bidirec is true.
//   T                - element type.
// gate_num: gates per cell; a value of 0 is treated as 1 when shaping the
// gate buffer.
template <typename GradCellType,
          template <typename, typename> class SingleGradLayerT,
          template <typename, typename> class BidirGradLayerT, typename T>
void RnnGradFunc(const framework::ExecutionContext& context,
                 const int& gate_num) {
  // get the tensor pointer for the input
  auto* input = context.Input<Tensor>("Input");
  auto weight_list = context.MultiInput<Tensor>("WeightList");
  auto pre_state = context.MultiInput<Tensor>("PreState");

  const Tensor* init_h = pre_state[0];
  const Tensor* init_c = nullptr;
  if (is_lstm(context)) {
    init_c = pre_state[1];
  }
  auto* reserve_state = context.Input<Tensor>("Reserve");
  auto* dropout_state = context.Input<Tensor>("DropoutState");
  auto* output = context.Input<Tensor>("Out");
  auto* output_grad = context.Input<Tensor>(framework::GradVarName("Out"));
  auto state_grad = context.MultiInput<Tensor>(framework::GradVarName("State"));
  const Tensor* last_h_grad = state_grad[0];
  const Tensor* last_c_grad = nullptr;
  if (is_lstm(context)) {
    last_c_grad = state_grad[1];
  }

  bool has_seq_length = context.HasInput("SequenceLength");
  const Tensor* sequence_length = nullptr;
  if (has_seq_length) {
    sequence_length = context.Input<Tensor>("SequenceLength");
  }

  // get the tensor pointer for the output
  auto* input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
  auto weight_grad_list = context.MultiOutput<framework::Tensor>(
      framework::GradVarName("WeightList"));
  auto pre_state_grad =
      context.MultiOutput<Tensor>(framework::GradVarName("PreState"));
  Tensor* init_h_grad = nullptr;
  Tensor* init_c_grad = nullptr;
  if (pre_state_grad.size() > 0) {  // has gradient
    init_h_grad = pre_state_grad[0];
    if (is_lstm(context)) {
      init_c_grad = pre_state_grad[1];
    }
  }

  // get the attributes for the calculation
  const int& num_layers = context.Attr<int>("num_layers");
  const bool& is_bidirec = context.Attr<bool>("is_bidirec");
  const float& dropout_prob = context.Attr<float>("dropout_prob");
  bool is_test =
      context.HasAttr("is_test") ? context.Attr<bool>("is_test") : false;

  // get the input_size, batch_size, time_step, hidden_size
  const int& time_step = input->dims()[0];
  const int& batch_size = input->dims()[1];
  const int& hidden_size = context.Attr<int>("hidden_size");
  const int& direction_num = is_bidirec ? 2 : 1;
  // allocate the memory and initialize the input_grad
  Tensor input_grad_value;
  if (!input_grad) {
    input_grad = &input_grad_value;
  }
  input_grad->mutable_data<T>(input->dims(), context.GetPlace());

  if (init_h_grad) {
    init_h_grad->mutable_data<T>(init_h->dims(), context.GetPlace());
  }
  if (init_c_grad) {
    init_c_grad->mutable_data<T>(init_c->dims(), context.GetPlace());
  }

  // reset the parameter to sorted order and allocate the memory
  std::vector<TensorList> parameter_lists;
  parameter_lists.reserve(num_layers);
  reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists);

  for (unsigned int i = 0; i < weight_grad_list.size(); ++i) {
    weight_grad_list[i]->mutable_data<T>(context.GetPlace());
  }
  std::vector<TensorList> parameter_lists_grad;
  parameter_lists_grad.reserve(num_layers);
  reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec,
                         &parameter_lists_grad);

  // resolve the state of reverse_state
  Tensor gate_tensor;
  Tensor state_tensor;
  Tensor act_state_tensor;
  Tensor hidden_tensor;
  SplitReserveData(context, reserve_state, &gate_tensor, &state_tensor,
                   &act_state_tensor, &hidden_tensor, direction_num, time_step,
                   batch_size, hidden_size, gate_num, num_layers);
  int gate_num_tmp = gate_num;
  if (gate_num == 0) {
    gate_num_tmp = 1;
  }
  gate_tensor.Resize({num_layers, time_step * direction_num, batch_size,
                      hidden_size * gate_num_tmp});
  if (state_tensor.numel() > 0) {
    state_tensor.Resize(
        {num_layers, time_step * direction_num, batch_size, hidden_size});
  }
  if (act_state_tensor.numel() > 0) {
    act_state_tensor.Resize(
        {num_layers, time_step * direction_num, batch_size, hidden_size});
  }
  if (num_layers > 1) {
    hidden_tensor.Resize(
        {num_layers - 1, time_step, batch_size, hidden_size * direction_num});
  }
  // unbind
  auto last_h_grad_unbind = Unbind(*last_h_grad);
  auto gate_tensor_unbind = Unbind(gate_tensor);
  TensorList last_c_grad_unbind;
  if (last_c_grad) {
    last_c_grad_unbind = Unbind(*last_c_grad);
  }

  TensorList init_h_unbind, init_c_unbind;
  TensorList init_h_grad_unbind, init_c_grad_unbind;
  TensorList state_tensor_unbind, act_state_tensor_unbind;
  TensorList hidden_tensor_unbind;

  init_h_unbind = Unbind(*init_h);
  if (init_c) {
    init_c_unbind = Unbind(*init_c);
  }

  if (init_h_grad != nullptr) {
    init_h_grad_unbind = Unbind(*init_h_grad);
  }
  if (init_c_grad != nullptr) {
    init_c_grad_unbind = Unbind(*init_c_grad);
  }
  if (state_tensor.numel() > 0) {
    state_tensor_unbind = Unbind(state_tensor);
  }
  if (act_state_tensor.numel() > 0) {
    act_state_tensor_unbind = Unbind(act_state_tensor);
  }
  if (num_layers > 1) {
    hidden_tensor_unbind = Unbind(hidden_tensor);
  }
  // squeeze the hidden first dim
  for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) {
    hidden_tensor_unbind[i].Resize(
        framework::slice_ddim(hidden_tensor_unbind[i].dims(), 1,
                              hidden_tensor_unbind[i].dims().size()));
  }
  // add the output tensor to the hidden vector
  Tensor tmp;
  hidden_tensor_unbind.emplace_back(tmp);
  hidden_tensor_unbind[num_layers - 1].ShareDataWith(*output);

  GradCellType cell;
  Tensor layer_input;
  Tensor layer_output;
  Tensor* layer_input_grad_holder = nullptr;
  Tensor tmp_out;
  tmp_out.ShareDataWith(*output_grad);
  Tensor* layer_output_grad_holder = &tmp_out;
  Tensor input_grad_temp;
  Tensor output_grad_temp;

  bool has_allocate_mem = false;
  for (int i = num_layers - 1; i >= 0; --i) {
    // the layer input output had saved, just use the data
    if (i > 0) {
      if (layer_input.numel() == 0) {
        layer_input.Resize(hidden_tensor_unbind[i - 1].dims());
        layer_input.mutable_data<T>(context.GetPlace());
      }
      dropout_helper<T>(context, &hidden_tensor_unbind[i - 1], &layer_input,
                        dropout_state, dropout_prob);
    } else {
      layer_input.ShareDataWith(*input);
    }
    layer_output.ShareDataWith(hidden_tensor_unbind[i]);
    if (num_layers == 1) {
      layer_input_grad_holder = input_grad;
    } else {
      if (i == num_layers - 1) {
        input_grad_temp.Resize(layer_input.dims());
        input_grad_temp.mutable_data<T>(context.GetPlace());
        layer_input_grad_holder = &input_grad_temp;
      }
    }
    if (is_bidirec) {
      BidirGradLayerT<T, GradCellType> layer(cell);
      layer(context, &layer_input, &layer_output, &init_h_unbind,
            &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
            gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
            layer_output_grad_holder, parameter_lists, sequence_length,
            layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
            &parameter_lists_grad, i, gate_num_tmp);
    } else {
      SingleGradLayerT<T, GradCellType> layer(cell);
      layer(context, &layer_input, &layer_output, &init_h_unbind,
            &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind,
            gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind,
            layer_output_grad_holder, parameter_lists, sequence_length,
            layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind,
            &parameter_lists_grad, i, gate_num_tmp);
    }

    // calculate the dropout gradient for the layer_input_grad_holder
    // dropout_state save in the forward process
    if (i > 0) {
      if ((!is_test) && (dropout_prob != 0)) {
        dropout_cpu_grad_function_inplace<T>(context, layer_input_grad_holder,
                                             dropout_state, dropout_prob);
      }
    }

    // The first layer is the last one processed. Nothing runs after it, so
    // stop here instead of allocating (and never using) another
    // gradient-sized temporary as the old code did when num_layers <= 2.
    if (i == 0) {
      break;
    }
    if (i == 1) {
      // The next layer is the first one; its input gradient is the final
      // input_grad.
      layer_output_grad_holder = input_grad;
    } else if (!has_allocate_mem) {
      output_grad_temp.Resize(layer_input_grad_holder->dims());
      output_grad_temp.mutable_data<T>(context.GetPlace());
      layer_output_grad_holder = &output_grad_temp;
      has_allocate_mem = true;
    }
    // Ping-pong the buffers: this layer's input gradient is the next
    // (lower) layer's output gradient.
    SwapPoniter(&layer_input_grad_holder, &layer_output_grad_holder);
  }
}

// CPU kernel for rnn_grad.
//
// Dispatches RnnGradFunc with the gradient cell matching the op's mode and
// that cell's gate count. Previously an unrecognised mode fell through
// silently and left every output gradient unset; it now raises an error.
template <typename DeviceContext, typename T>
class RNNCPUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    constexpr int kLstmGateNum = 4;
    constexpr int kGruGateNum = 3;
    constexpr int kSimpleRnnGateNum = 1;
    if (is_lstm(ctx)) {
      RnnGradFunc<LSTMGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
          ctx, kLstmGateNum);
    } else if (is_gru(ctx)) {
      RnnGradFunc<GRUGradCell<T>, SingleGradLayer, BidirGradLayer, T>(
          ctx, kGruGateNum);
    } else if (is_rnn_relu(ctx)) {
      RnnGradFunc<SimpleRNNGradCell<T, ReluGradFunctor>, SingleGradLayer,
                  BidirGradLayer, T>(ctx, kSimpleRnnGateNum);
    } else if (is_rnn_tanh(ctx)) {
      RnnGradFunc<SimpleRNNGradCell<T, TanhGradFunctor>, SingleGradLayer,
                  BidirGradLayer, T>(ctx, kSimpleRnnGateNum);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The CPU rnn_grad kernel only supports the LSTM, GRU, RNN_RELU "
          "and RNN_TANH modes."));
    }
  }
};
}  // namespace operators
}  // namespace paddle