/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/gru_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/unique_op.h" #include "paddle/fluid/operators/utils.h" #include "paddle/pten/kernels/funcs/blas/blas.h" #include "paddle/pten/kernels/funcs/math_function.h" namespace paddle { namespace operators { using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; using TensorList = std::vector; #define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \ inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \ const std::string& mode = ctx.Attr("mode"); \ return mode == #MODE_STR; \ } DEFINE_MODE_DETECTOR(lstm, LSTM); DEFINE_MODE_DETECTOR(gru, GRU); DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU); DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH); void SwapPoniter(Tensor** a, Tensor** b) { Tensor* c = *a; *a = *b; *b = c; } template void create_mask_matrix(const framework::ExecutionContext& context, const Tensor* sequence_length, Tensor* mask_matrix, const bool& is_reverse, int* min_seq_len) { const auto& seq_len_vec = GetDataFromTensor(sequence_length); const int& table_width = mask_matrix->dims()[0]; Tensor temp; temp.Resize( pten::make_ddim({mask_matrix->dims()[1], mask_matrix->dims()[0]})); T* data_temp = temp.mutable_data(context.GetPlace()); std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast(1.0)); *min_seq_len = table_width; for (unsigned int i = 0; i < seq_len_vec.size(); i++) { // reset the mask matrix *min_seq_len = std::min(seq_len_vec[i], *min_seq_len); if (seq_len_vec[i] == table_width) { continue; } if (is_reverse) { std::fill(data_temp + i * table_width, data_temp + (i + 1) * table_width - seq_len_vec[i], static_cast(0)); } else { std::fill(data_temp + i * table_width + seq_len_vec[i], data_temp + (i + 1) * table_width, static_cast(0)); } } mask_matrix->mutable_data(context.GetPlace()); std::vector trans_vec; trans_vec.emplace_back(1); trans_vec.emplace_back(0); auto& dev_ctx = context.template device_context(); TransCompute(2, dev_ctx, temp, mask_matrix, trans_vec); } template struct Cell { virtual ~Cell() {} virtual void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, const Tensor* weight_hh, const Tensor* init_h, const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const {} }; template class EigenActivationFunctor, math::detail::ActivationType act_type> struct SimpleRNNCell : Cell { void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, const Tensor* weight_hh, const Tensor* init_h, const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const override { auto blas = pten::funcs::GetBlas(*device_ctx); auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false); auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; // convert the batch matmul to matmul, this operator could be speed faster blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), input, static_cast(1.0)); auto z = EigenVector::Flatten( GET_DATA_SAFELY(input, "Input", "z", "Activation")); auto hidden = EigenVector::Flatten( GET_DATA_SAFELY(output, "Output", "hidden", "Activation")); auto* place = device_ctx->eigen_device(); EigenActivationFunctor functor; functor(*place, z, hidden); } }; template struct GRUCell : Cell { void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, const Tensor* weight_hh, const Tensor* init_h, const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const override { auto blas = pten::funcs::GetBlas(*device_ctx); auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false); auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; // convert the batch matmul to matmul, this operator could be speed faster blas.MatMul(*init_h, mat_dim_a, *weight_hh_gru, mat_dim_b, static_cast(1.0), input, static_cast(1.0)); size_t frame_size = init_h->dims()[2]; size_t batch_size = init_h->dims()[1]; math::GRUMetaValue gru_value; gru_value.gate_weight = weight_hh->data(); gru_value.state_weight = weight_hh->data() + 2 * frame_size * frame_size; gru_value.reset_bias = bias_hh->data() + 2 * frame_size; gru_value.gate_value = input->data(); gru_value.reset_output_value = last_c->data(); gru_value.output_value = output->data(); gru_value.prev_out_value = init_h->data(); auto gate_act = math::detail::GetActivationType("sigmoid_v2"); auto cand_act = math::detail::GetActivationType("tanh_v2"); math::GRUUnitFunctorV2::compute( *device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act); } }; template struct LSTMCell : Cell { void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, const Tensor* weight_hh, const Tensor* init_h, const Tensor* init_c, Tensor* last_h, Tensor* last_c, Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, Tensor* weight_hh_gru) const override { auto blas = pten::funcs::GetBlas(*device_ctx); auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(init_h->dims(), 0, false); auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, true); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; // convert the batch matmul to matmul, this operator could be speed faster blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), input, static_cast(1.0)); math::LstmMetaValue lstm_value; lstm_value.check_ig = nullptr; lstm_value.check_fg = nullptr; lstm_value.check_og = nullptr; auto gate_act = math::detail::GetActivationType("sigmoid_v2"); auto cell_act = math::detail::GetActivationType("tanh_v2"); auto cand_act = math::detail::GetActivationType("tanh_v2"); size_t frame_size = init_h->dims()[2]; size_t batch_size = init_h->dims()[1]; Tensor cell_pre_act; if (last_c_act == nullptr) { /* is test */ cell_pre_act.mutable_data(init_h->dims(), device_ctx->GetPlace()); last_c_act = &cell_pre_act; } lstm_value.prev_state_value = init_c->data(); lstm_value.gate_value = input->data(); lstm_value.output_value = output->data(); lstm_value.state_value = last_c->data(); lstm_value.state_active_value = last_c_act->data(); T cell_clip = 0.0; math::LstmUnitFunctor::compute( *device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act, cell_act, cand_act, false); } }; template void dropout_helper(const framework::ExecutionContext& context, Tensor* x, Tensor* y, const Tensor* mask, const float& dropout_prob) { auto& place = *context.template device_context() .eigen_device(); auto dropout_mask = EigenVector::Flatten(*mask); auto in = EigenVector::Flatten(*x); auto out = EigenVector::Flatten(*y); if (dropout_prob == 1.0f) { out.device(place) = static_cast(0) * in; } else { out.device(place) = in * dropout_mask.cast() / static_cast(1.0f - dropout_prob); } } template void dropout_cpu_function_inplace(const framework::ExecutionContext& context, Tensor* x, Tensor* y, Tensor* mask, const float& dropout_prob, const int& seed_number, bool is_test, bool* is_has_reset) { if (is_test) { return; } size_t size = pten::product(x->dims()); auto* mask_data = mask->data(); if (!(*is_has_reset)) { // Special case when dropout_prob is 1.0 if (dropout_prob == 1.0f) { std::fill(mask_data, mask_data + size, static_cast(0)); } else { auto engine = framework::GetCPURandomEngine(seed_number); std::uniform_real_distribution dist(0, 1); for (size_t i = 0; i < size; ++i) { if (dist(*engine) < dropout_prob) { mask_data[i] = 0; } else { mask_data[i] = 1; } } } *is_has_reset = true; } dropout_helper(context, x, y, mask, dropout_prob); } template void dropout_cpu_grad_function_inplace( const framework::ExecutionContext& context, Tensor* grad_x, const Tensor* mask, const float& dropout_prob) { dropout_helper(context, grad_x, grad_x, mask, dropout_prob); } template struct Layer { explicit Layer(const CellType& cell) : cell_(cell) {} virtual ~Layer() {} void preprocess(const framework::ExecutionContext& context, const Tensor* input, const Tensor& weight, const Tensor& bias_ih, const Tensor& bias_hh, Tensor* cache_input, bool is_test) { // crate the temp input for the X * W_ih^T + Bias_ih auto& dev_ctx = context.template device_context(); const int& hidden_size = weight.dims()[0]; cache_input->Resize( pten::make_ddim({input->dims()[0], input->dims()[1], hidden_size})); if (is_test) { cache_input->mutable_data(context.GetPlace()); } auto blas = pten::funcs::GetBlas(dev_ctx); auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(input->dims(), 0, false); auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(weight.dims(), 0, true); // convert the batch matmul to matmul, this operator could be speed faster mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; blas.MatMul(*input, mat_dim_a, weight, mat_dim_b, static_cast(1.0), cache_input, static_cast(0)); auto in = framework::EigenMatrix::Reshape( *cache_input, cache_input->dims().size() - 1); auto bias_ih_tmp = framework::EigenMatrix::From( bias_ih, pten::make_ddim({1, bias_ih.dims()[0]})); const int& row_num = pten::product(cache_input->dims()) / cache_input->dims()[2]; in = in + bias_ih_tmp.broadcast(Eigen::DSizes(row_num, 1)); if (is_gru(context)) { // reset_gate update_gate cell_gate = [1, 1, 0] Tensor bias_hh_tmp; bias_hh_tmp.Resize({bias_hh.numel()}); bias_hh_tmp.mutable_data(context.GetPlace()); framework::TensorCopy(bias_hh, context.GetPlace(), dev_ctx, &bias_hh_tmp); bias_hh_tmp.Resize({3, bias_hh_tmp.numel() / 3}); auto bias_hh_tmp_unbind = Unbind(bias_hh_tmp); pten::funcs::SetConstant zero; zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast(0.0)); auto bias_hh_after_mask = framework::EigenMatrix::From( bias_hh_tmp, pten::make_ddim({1, bias_hh.dims()[0]})); in = in + bias_hh_after_mask.broadcast(Eigen::DSizes(row_num, 1)); } else { auto bias_hh_no_mask = framework::EigenMatrix::From( bias_hh, pten::make_ddim({1, bias_hh.dims()[0]})); in = in + bias_hh_no_mask.broadcast(Eigen::DSizes(row_num, 1)); } } void postprocess(const framework::ExecutionContext& context, Tensor* output, const Tensor* init_h, const Tensor* init_c, Tensor* last_h, Tensor* last_c, const Tensor& mask_tensor) { // in the output, if mask flag is 0, we will retun the zero data auto& place = *context.template device_context() .eigen_device(); auto out = framework::EigenMatrix::Reshape(*output, output->dims().size() - 1); auto mask = framework::EigenMatrix::From( mask_tensor, pten::make_ddim({mask_tensor.dims()[1], 1})); auto pre_h = framework::EigenMatrix::Reshape(*init_h, init_h->dims().size() - 1); auto curr_h = framework::EigenMatrix::Reshape(*last_h, last_h->dims().size() - 1); auto mask_broadcast = mask.broadcast(Eigen::DSizes(1, output->dims()[2])); curr_h.device(place) = out * mask_broadcast + pre_h * (1 - mask_broadcast); out.device(place) = out * mask_broadcast; if (is_lstm(context)) { auto pre_c = framework::EigenMatrix::Reshape( *init_c, init_c->dims().size() - 1); auto curr_c = framework::EigenMatrix::Reshape( *last_c, last_c->dims().size() - 1); curr_c.device(place) = curr_c * mask_broadcast + pre_c * (1 - mask_broadcast); } } virtual void operator()(const framework::ExecutionContext& context, const Tensor* input, const TensorList& vec, const TensorList& init_h, const TensorList& init_c, const Tensor* sequence_length, TensorList last_h, TensorList last_c, Tensor* output, const int& layer_idx, const int& gate_num, Tensor* gate_value, Tensor* cell_value, Tensor* cell_act_value, bool is_test) {} void RunTestIter(const framework::ExecutionContext& context, const Tensor* input, const TensorList& vec, const TensorList& init_h, const TensorList& init_c, const Tensor* sequence_length, TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output, int layer_idx, Tensor* gate_value, Tensor* cell_value, Tensor* cell_act_value, bool is_bidirect, int offset) { bool is_reverse = false; if (is_bidirect) { layer_idx = 2 * layer_idx + offset; if (offset > 0) { is_reverse = true; } } auto& dev_ctx = context.template device_context(); const int& time_step = input->dims()[0]; this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4], vec[3 + offset * 4], gate_value, true); auto input_tensors = Unbind(*gate_value); auto output_tensors = Unbind(*output); if (is_reverse) { std::reverse(input_tensors.begin(), input_tensors.end()); std::reverse(output_tensors.begin(), output_tensors.end()); } TensorList mask_tensor_list; // construct the mask matrix for the mask bool has_sequence_length = false; if (sequence_length != nullptr) { has_sequence_length = true; } Tensor mask_matrix; int mask_min_length = time_step; if (has_sequence_length) { mask_matrix.Resize(pten::make_ddim({time_step, input->dims()[1]})); create_mask_matrix(context, sequence_length, &mask_matrix, is_reverse, &mask_min_length); mask_tensor_list = Unbind(mask_matrix); } if (is_reverse) { mask_min_length = mask_min_length - time_step + 1; } bool has_allocate_mem_c = false; bool has_use_last_h_holder = false; const int& reverse_flag = is_reverse ? -1 : 1; // define the init_h holder for the swap Tensor init_h_temp; framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx, &init_h_temp); Tensor* init_h_holder = &init_h_temp; Tensor* last_h_holder = nullptr; if (0 < mask_min_length) { last_h_holder = &(output_tensors[0]); } else { last_h_holder = &(*last_h_ptr)[layer_idx]; has_use_last_h_holder = true; } Tensor* init_c_holder = nullptr; const Tensor* init_c_temp_holder = nullptr; Tensor init_c_temp; Tensor* last_c_holder = nullptr; Tensor last_c_temp; if (is_lstm(context)) { last_c_holder = &(*last_c_ptr)[layer_idx]; init_c_temp_holder = &init_c[layer_idx]; } else if (is_gru(context)) { // for reset output value last_c_temp.Resize(init_h[layer_idx].dims()); last_c_temp.mutable_data(context.GetPlace()); last_c_holder = &last_c_temp; } Tensor weight_hh_tmp; // for gru if (is_gru(context)) { weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); weight_hh_tmp.mutable_data(context.GetPlace()); framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx, &weight_hh_tmp); weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3}); auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp); pten::funcs::SetConstant zero; zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast(0.0)); weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); } for (int i = 0; i < time_step; i++) { bool in_mask = (reverse_flag * i) >= mask_min_length; if (i > 0) { if (!has_allocate_mem_c) { if (is_lstm(context) || is_gru(context)) { init_c_temp.Resize(init_h[layer_idx].dims()); init_c_temp.mutable_data(context.GetPlace()); init_c_holder = &init_c_temp; } has_allocate_mem_c = true; } SwapPoniter(&init_c_holder, &last_c_holder); init_c_temp_holder = init_c_holder; } cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder, init_c_temp_holder, last_h_holder, last_c_holder, nullptr, &output_tensors[i], &vec[3 + offset * 4] /* bias_hh */, &weight_hh_tmp); if (in_mask) { this->postprocess(context, &output_tensors[i], init_h_holder, init_c_temp_holder, last_h_holder, last_c_holder, mask_tensor_list[i]); } // prepare next step if (i + 1 < time_step) { bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length; if (next_step_mask) { if (!has_use_last_h_holder) { init_h_holder = &(*last_h_ptr)[layer_idx]; } } else { init_h_holder = &(output_tensors[i + 1]); } SwapPoniter(&init_h_holder, &last_h_holder); } } if (has_sequence_length) { if (last_h_holder != &(*last_h_ptr)[layer_idx]) { framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx, &(*last_h_ptr)[layer_idx]); } } else { framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(), dev_ctx, &(*last_h_ptr)[layer_idx]); } if (time_step % 2 == 0) { if (is_lstm(context)) { framework::TensorCopy(*last_c_holder, context.GetPlace(), dev_ctx, &(*last_c_ptr)[layer_idx]); } } } void RunIter(const framework::ExecutionContext& context, const Tensor* input, const TensorList& vec, const TensorList& init_h, const TensorList& init_c, const Tensor* sequence_length, TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output, int layer_idx, Tensor* gate_value, Tensor* cell_value, Tensor* cell_act_value, bool is_bidirect, int offset, bool is_test) { if (is_test) { RunTestIter(context, input, vec, init_h, init_c, sequence_length, last_h_ptr, last_c_ptr, output, layer_idx, gate_value, cell_value, cell_act_value, is_bidirect, offset); return; } bool is_reverse = false; if (is_bidirect) { layer_idx = 2 * layer_idx + offset; if (offset > 0) { is_reverse = true; } } auto& dev_ctx = context.template device_context(); const int& time_step = input->dims()[0]; this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4], vec[3 + offset * 4], gate_value, is_test); auto input_tensors = Unbind(*gate_value); auto output_tensors = Unbind(*output); if (is_reverse) { std::reverse(input_tensors.begin(), input_tensors.end()); std::reverse(output_tensors.begin(), output_tensors.end()); } TensorList mask_tensor_list; // construct the mask matrix for the mask bool has_sequence_length = false; if (sequence_length != nullptr) { has_sequence_length = true; } Tensor mask_matrix; int mask_min_length = time_step; if (has_sequence_length) { mask_matrix.Resize(pten::make_ddim({time_step, input->dims()[1]})); create_mask_matrix(context, sequence_length, &mask_matrix, is_reverse, &mask_min_length); mask_tensor_list = Unbind(mask_matrix); } if (is_reverse) { mask_min_length = mask_min_length - time_step + 1; } // define the init_h holder for the swap bool has_use_last_h_holder = false; const int& reverse_flag = is_reverse ? -1 : 1; TensorList cell_value_tensors; TensorList cell_act_value_tensors; Tensor init_h_temp; framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx, &init_h_temp); Tensor* init_h_holder = &init_h_temp; Tensor* last_h_holder = nullptr; if (0 < mask_min_length) { last_h_holder = &(output_tensors[0]); } else { last_h_holder = &(*last_h_ptr)[layer_idx]; has_use_last_h_holder = true; } const Tensor* init_c_holder = nullptr; Tensor* last_c_holder = nullptr; Tensor* last_c_act_holder = nullptr; if (is_lstm(context) || is_gru(context)) { cell_value->Resize({time_step, cell_value->numel() / time_step}); cell_value_tensors = Unbind(*cell_value); if (is_lstm(context)) { cell_act_value->Resize( {time_step, cell_act_value->numel() / time_step}); cell_act_value_tensors = Unbind(*cell_act_value); } } Tensor weight_hh_tmp; // for gru if (is_gru(context)) { weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); weight_hh_tmp.mutable_data(context.GetPlace()); framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx, &weight_hh_tmp); weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3}); auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp); pten::funcs::SetConstant zero; zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast(0.0)); weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); } for (int i = 0; i < time_step; i++) { bool in_mask = (reverse_flag * i) >= mask_min_length; if (is_lstm(context)) { if (i == 0) { init_c_holder = &init_c[layer_idx]; } else { init_c_holder = &cell_value_tensors[i - 1]; } cell_value_tensors[i].Resize(init_c[layer_idx].dims()); cell_act_value_tensors[i].Resize(init_c[layer_idx].dims()); last_c_holder = &cell_value_tensors[i]; last_c_act_holder = &cell_act_value_tensors[i]; } else if (is_gru(context)) { cell_value_tensors[i].Resize(init_h[layer_idx].dims()); last_c_holder = &cell_value_tensors[i]; } cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder, init_c_holder, last_h_holder, last_c_holder, last_c_act_holder, &output_tensors[i], &vec[3 + offset * 4] /* bias_hh */, &weight_hh_tmp); if (in_mask) { this->postprocess(context, &output_tensors[i], init_h_holder, init_c_holder, last_h_holder, last_c_holder, mask_tensor_list[i]); } // prepare next step if (i + 1 < time_step) { bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length; if (next_step_mask) { if (!has_use_last_h_holder) { init_h_holder = &(*last_h_ptr)[layer_idx]; } } else { init_h_holder = &(output_tensors[i + 1]); } SwapPoniter(&init_h_holder, &last_h_holder); } } if (has_sequence_length) { if (last_h_holder != &(*last_h_ptr)[layer_idx]) { framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx, &(*last_h_ptr)[layer_idx]); } } else { framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(), dev_ctx, &(*last_h_ptr)[layer_idx]); } if (is_lstm(context)) { framework::TensorCopy(cell_value_tensors[time_step - 1], context.GetPlace(), dev_ctx, &(*last_c_ptr)[layer_idx]); } } // Cell for the rnn module CellType cell_; }; template struct SingleLayer : public Layer { explicit SingleLayer(const CellType& cell) : Layer(cell) {} void operator()(const framework::ExecutionContext& context, const Tensor* input, const TensorList& vec, const TensorList& init_h, const TensorList& init_c, const Tensor* sequence_length, TensorList last_h, TensorList last_c, Tensor* output, const int& layer_idx, const int& gate_num, Tensor* gate_value, Tensor* cell_value, Tensor* cell_act_value, bool is_test) { this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h, &last_c, output, layer_idx, gate_value, cell_value, cell_act_value, false, 0, is_test); } }; template struct BidirLayer : public Layer { explicit BidirLayer(const CellType& cell) : Layer(cell) {} void operator()(const framework::ExecutionContext& context, const Tensor* input, const TensorList& vec, const TensorList& init_h, const TensorList& init_c, const Tensor* sequence_length, TensorList last_h, TensorList last_c, Tensor* output, const int& layer_idx, const int& gate_num, Tensor* gate_value, Tensor* cell_value, Tensor* cell_act_value, bool is_test) { TensorList output_vec(2); Tensor forward_input_w, forward_cell_value, forward_cell_act_value; Tensor backward_input_w, backward_cell_value, backward_cell_act_value; int time_step = input->dims()[0]; int batch_size = input->dims()[1]; int hidden_size = output->dims()[2]; for (int i = 0; i < 2; ++i) { output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); output_vec[i].mutable_data(context.GetPlace()); } if (!is_test) { gate_value->Resize({2, gate_value->numel() / 2}); forward_input_w = gate_value->Slice(0, 1); backward_input_w = gate_value->Slice(1, 2); if (is_lstm(context) || is_gru(context)) /* for lstm and gru */ { cell_value->Resize({2, cell_value->numel() / 2}); cell_act_value->Resize({2, cell_act_value->numel() / 2}); forward_cell_value = cell_value->Slice(0, 1); backward_cell_value = cell_value->Slice(1, 2); if (is_lstm(context)) { forward_cell_act_value = cell_act_value->Slice(0, 1); backward_cell_act_value = cell_act_value->Slice(1, 2); } } } this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h, &last_c, &output_vec[0], layer_idx, &forward_input_w, &forward_cell_value, &forward_cell_act_value, true, 0, is_test); this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h, &last_c, &output_vec[1], layer_idx, &backward_input_w, &backward_cell_value, &backward_cell_act_value, true, 1, is_test); // concat the the output result auto& dev_ctx = context.template device_context(); paddle::operators::math::ConcatFunctor concat_functor; concat_functor(dev_ctx, output_vec, static_cast(2), output); } }; template void SplitReserveData(const framework::ExecutionContext& ctx, TensorType* reserve_data, Tensor* gate_data, Tensor* cell_data, Tensor* cell_act_data, Tensor* hidden_data, int direction_num, const int& time_step, const int& batch_size, const int& hidden_size, const int& gate_num, const int& num_layers) { const int& gate_data_idx = gate_num * num_layers; const int& cell_data_idx = (gate_num + 1) * num_layers; const int& cell_act_data_idx = (gate_num + 2) * num_layers; // simple rnn int hidden_data_start_idx = gate_data_idx; *gate_data = reserve_data->Slice(0, gate_data_idx); if (is_lstm(ctx)) { *cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx); *cell_act_data = reserve_data->Slice(cell_data_idx, cell_act_data_idx); hidden_data_start_idx = cell_act_data_idx; } else if (is_gru(ctx)) { *cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx); hidden_data_start_idx = cell_data_idx; } int hidden_data_idx = hidden_data_start_idx + (num_layers - 1); if (num_layers > 1) { *hidden_data = reserve_data->Slice(hidden_data_start_idx, hidden_data_idx); } } template void reset_parameter_vector(const std::vector& raw_params_vec, const int& num_layers, const int& gate_num, const bool& is_bidirec, std::vector* params_vec) { // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers const int& direction_num = is_bidirec ? 2 : 1; const int& layer_weight_size = 4 * direction_num; const int& all_weight_size = num_layers * layer_weight_size; const int& bias_start_idx = all_weight_size / 2; for (int i = 0; i < num_layers; i++) { TensorList tensor_list; tensor_list.reserve(layer_weight_size); for (int j = 0; j < layer_weight_size; j++) { Tensor tensor_holder; tensor_list.emplace_back(tensor_holder); } for (int j = 0; j < layer_weight_size; j++) { int k = j % 4; const int& section = j / 4; int tensor_idx = i * 2 * direction_num + section * 2 + k % 2; if (k >= 2) { tensor_idx += bias_start_idx; } tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]); } params_vec->emplace_back(tensor_list); } } template void AllocateReserveData(const framework::ExecutionContext& ctx, Tensor* reserve_data, Tensor* gate_data, Tensor* cell_data, Tensor* cell_act_data, Tensor* hidden_data, const Tensor* input, bool is_bidirec, int num_layers, int gate_num, int hidden_size) { const int& direction_num = is_bidirec ? 2 : 1; const int& time_step = input->dims()[0]; const int& batch_size = input->dims()[1]; const int& block_size = direction_num * time_step * batch_size * hidden_size; int hidden_data_idx = (num_layers - 1); if (is_lstm(ctx)) { hidden_data_idx += (gate_num + 2) * num_layers; } else if (is_gru(ctx)) { hidden_data_idx += (gate_num + 1) * num_layers; } else { hidden_data_idx += gate_num * num_layers; } reserve_data->Resize({hidden_data_idx, block_size}); reserve_data->mutable_data(ctx.GetPlace()); SplitReserveData(ctx, reserve_data, gate_data, cell_data, cell_act_data, hidden_data, direction_num, time_step, batch_size, hidden_size, gate_num, num_layers); } template class LayerT, template class SingleLayerT, template class BidirLayerT, typename T> void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input, const std::vector weight_list, const Tensor* init_h, const Tensor* init_c, const Tensor* sequence_length, Tensor* last_h, Tensor* last_c, Tensor* output, Tensor* dropout_mask, const int& num_layers, const int& gate_num, const int& input_size, const int& hidden_size, const bool& is_bidirec, const std::string& cell_type, const float& dropout_prob, bool is_test, const int& seed, Tensor* reserve_data) { const int& direction_num = is_bidirec ? 2 : 1; const auto& init_h_dims = init_h->dims(); PADDLE_ENFORCE_EQ(init_h_dims[0], num_layers * direction_num, platform::errors::InvalidArgument( "The num_layers of in RNN layer must be the same as " "first dim of init hidden, but received" " num_layers:%d, dim:%d", num_layers, init_h_dims[0])); if (is_lstm(ctx)) { const auto& init_c_dims = init_c->dims(); PADDLE_ENFORCE_EQ(init_c_dims[0], num_layers * direction_num, platform::errors::InvalidArgument( "The num_layers of in RNN layer must be the same as " "first dim of cell state hidden, but received" " num_layers:%d, dim:%d", num_layers, init_h_dims[0])); } CellType cell; std::vector parameter_lists; parameter_lists.reserve(num_layers); reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec, ¶meter_lists); Tensor gate_data, cell_data, cell_act_data, hidden_data; if (!is_test) { AllocateReserveData( ctx, reserve_data, &gate_data, &cell_data, &cell_act_data, &hidden_data, input, is_bidirec, num_layers, gate_num, hidden_size); gate_data.Resize({num_layers, gate_data.numel() / num_layers}); cell_data.Resize({num_layers, cell_data.numel() / num_layers}); cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers}); if (num_layers > 1) { hidden_data.Resize( {num_layers - 1, hidden_data.numel() / (num_layers - 1)}); } } Tensor* input_holder; Tensor* output_holder = output; Tensor temp; bool has_allocate_mem = false; auto init_h_unbind = Unbind(*init_h); auto last_h_unbind = Unbind(*last_h); TensorList init_c_unbind, last_c_unbind; if (is_lstm(ctx)) { init_c_unbind = Unbind(*init_c); last_c_unbind = Unbind(*last_c); } Tensor curr_gate_data, curr_cell_data, curr_cell_act_data; Tensor curr_hidden_data, prev_hidden_data; bool has_dropout_reset = false; for (int i = 0; i < num_layers; i++) { if (!is_test) { if (cell_data.numel() > 0) /** for lstm, gru **/ { curr_cell_data = cell_data.Slice(i, i + 1); } if (cell_act_data.numel() > 0) /*for lstm*/ { curr_cell_act_data = cell_act_data.Slice(i, i + 1); } curr_gate_data = gate_data.Slice(i, i + 1); output_holder = output; if (i < num_layers - 1 && num_layers > 1) { curr_hidden_data = hidden_data.Slice(i, i + 1); curr_hidden_data.Resize(output->dims()); output_holder = &curr_hidden_data; } } if (i > 0) { if (!has_allocate_mem) { temp.Resize(output->dims()); temp.mutable_data(ctx.GetPlace()); input_holder = &temp; has_allocate_mem = true; } if (!is_test) { prev_hidden_data = hidden_data.Slice(i - 1, i); input_holder->Resize(output->dims()); if (dropout_prob != 0) { dropout_cpu_function_inplace(ctx, &prev_hidden_data, input_holder, dropout_mask, dropout_prob, seed, is_test, &has_dropout_reset); } else { input_holder = &prev_hidden_data; input_holder->Resize(output->dims()); } } else { SwapPoniter(&output_holder, &input_holder); } } const Tensor* input_temp_holder = input; if (i > 0) { input_temp_holder = input_holder; } LayerT* layer; SingleLayerT slayer(cell); BidirLayerT blayer(cell); if (is_bidirec) { layer = &blayer; } else { layer = &slayer; } (*layer)(ctx, input_temp_holder, parameter_lists[i], init_h_unbind, init_c_unbind, sequence_length, last_h_unbind, last_c_unbind, output_holder, i, gate_num, &curr_gate_data, &curr_cell_data, &curr_cell_act_data, is_test); } if (num_layers % 2 == 0) { framework::TensorCopy( *output_holder, ctx.GetPlace(), ctx.template device_context(), output); } } template class RNNCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto pre_state = ctx.MultiInput("PreState"); auto weight_list = ctx.MultiInput("WeightList"); auto state = ctx.MultiOutput("State"); auto* output = ctx.Output("Out"); auto* dropout_mask = ctx.Output("DropoutState"); auto* reserve_data = ctx.Output("Reserve"); const int& num_layers = ctx.Attr("num_layers"); const bool& is_bidirec = ctx.Attr("is_bidirec"); const int& input_size = ctx.Attr("input_size"); const int& hidden_size = ctx.Attr("hidden_size"); const float& dropout_prob = ctx.Attr("dropout_prob"); const std::string& mode = ctx.Attr("mode"); const int& seed = ctx.Attr("seed"); bool is_test = ctx.HasAttr("is_test") ? ctx.Attr("is_test") : false; bool has_seq_length = ctx.HasInput("SequenceLength"); const Tensor* sequence_length = nullptr; if (has_seq_length) { sequence_length = ctx.Input("SequenceLength"); } if (dropout_mask->IsInitialized()) { if (dropout_mask->numel() != output->numel()) dropout_mask->clear(); } dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); auto& dev_ctx = ctx.template device_context(); pten::funcs::SetConstant ones; ones(dev_ctx, dropout_mask, static_cast(1)); // init the output and allocate the memory output->mutable_data(ctx.GetPlace()); int gate_num = 4; state[0]->mutable_data(ctx.GetPlace()); if (is_lstm(ctx)) { state[1]->mutable_data(ctx.GetPlace()); RnnFunc, Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], pre_state[1], sequence_length, state[0], state[1], output, dropout_mask, num_layers, gate_num, input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, seed, reserve_data); } else if (is_rnn_relu(ctx)) { gate_num = 1; RnnFunc< SimpleRNNCell, Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], nullptr, sequence_length, state[0], nullptr, output, dropout_mask, num_layers, gate_num, input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, seed, reserve_data); } else if (is_rnn_tanh(ctx)) { gate_num = 1; RnnFunc< SimpleRNNCell, Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], nullptr, sequence_length, state[0], nullptr, output, dropout_mask, num_layers, gate_num, input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, seed, reserve_data); } else if (is_gru(ctx)) { gate_num = 3; RnnFunc, Layer, SingleLayer, BidirLayer, T>( ctx, input, weight_list, pre_state[0], nullptr, sequence_length, state[0], nullptr, output, dropout_mask, num_layers, gate_num, input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, seed, reserve_data); } } }; template void create_lstm_value(math::LstmMetaValue* lstm_value) { lstm_value->check_ig = nullptr; lstm_value->check_fg = nullptr; lstm_value->check_og = nullptr; } template void create_lstm_grad(math::LstmMetaGrad* lstm_grad) { lstm_grad->check_ig_grad = nullptr; lstm_grad->check_fg_grad = nullptr; lstm_grad->check_og_grad = nullptr; } template void create_tensor_by_list(const framework::ExecutionContext& context, Tensor* dst, const std::vector& v) { int tensor_size = v.size(); dst->Resize({tensor_size}); dst->mutable_data(context.GetPlace()); int size = v.size(); for (int i = 0; i < size; ++i) { dst->data()[i] = v[i]; } } template struct GradLayer { explicit GradLayer(const GradCellType& cell) : cell_(cell) {} virtual ~GradLayer() {} void run_rnn_grad_function( const framework::ExecutionContext& context, const platform::CPUDeviceContext& device_ctx, const Tensor* input, Tensor* input_grad, const Tensor* sequence_length, std::vector* init_h_unbind, std::vector* init_c_unbind, std::vector* init_h_grad_unbind, std::vector* init_c_grad_unbind, Tensor* layer_grad_gate_tensor, std::vector* layer_gate_tensor_unbind, std::vector* layer_grad_gate_tensor_unbind, std::vector* layer_state_tensor_unbind, std::vector* layer_act_state_tensor_unbind, std::vector* output_tensor_unbind, std::vector* output_grad_tensor_unbind, const TensorList& last_h_grad_unbind, const TensorList& last_c_grad_unbind, const std::vector& parameter_lists, std::vector* weight_list_grad, const int& layer_idx, const int& time_step, const bool& has_sequence_length, const bool& is_bidirec, const bool& is_reverse) { const int& direction_num = is_bidirec ? 2 : 1; const int& current_reverse_idx = is_reverse ? 1 : 0; const int& current_layer_idx = direction_num * layer_idx + current_reverse_idx; int begin_idx = 0; if (is_reverse) { begin_idx = time_step; } Tensor mask_matrix; TensorList mask_tensor_list; int mask_min_length = time_step; if (has_sequence_length) { mask_matrix.Resize(pten::make_ddim({time_step, input->dims()[1]})); create_mask_matrix(context, sequence_length, &mask_matrix, is_reverse, &mask_min_length); mask_tensor_list = Unbind(mask_matrix); } // copy the last_h, last_c for swaping pointer Tensor a, b; Tensor* dynamic_grad_last_h = &a; Tensor* dynamic_grad_last_c = &b; dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims()); dynamic_grad_last_h->mutable_data(context.GetPlace()); framework::TensorCopy(last_h_grad_unbind[current_layer_idx], context.GetPlace(), dynamic_grad_last_h); if (last_c_grad_unbind.size() > 0) { dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims()); dynamic_grad_last_c->mutable_data(context.GetPlace()); framework::TensorCopy(last_c_grad_unbind[current_layer_idx], context.GetPlace(), dynamic_grad_last_c); } else { dynamic_grad_last_c = nullptr; } Tensor c, d; Tensor* dynamic_grad_pre_h = &c; Tensor* dynamic_grad_pre_c = &d; pten::funcs::SetConstant zero; if (init_h_grad_unbind->size() > 0) { dynamic_grad_pre_h->ShareDataWith( (*init_h_grad_unbind)[current_layer_idx]); } else { dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims()); dynamic_grad_pre_h->mutable_data(context.GetPlace()); zero(device_ctx, dynamic_grad_pre_h, static_cast(0.0)); } if (init_c_grad_unbind->size() > 0) { dynamic_grad_pre_c->ShareDataWith( (*init_c_grad_unbind)[current_layer_idx]); } else { if (is_lstm(context) || is_gru(context)) { dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims()); dynamic_grad_pre_c->mutable_data(context.GetPlace()); if (is_gru(context)) { dynamic_grad_last_c = dynamic_grad_pre_c; } } else { dynamic_grad_pre_c = nullptr; } } if (is_reverse) { // must be reverse the input, output, input_grad, output_grad // the gate and grad_gate must be reverse std::reverse(layer_gate_tensor_unbind->begin(), layer_gate_tensor_unbind->end()); std::reverse(layer_grad_gate_tensor_unbind->begin(), layer_grad_gate_tensor_unbind->end()); /* if (has_sequence_length) { std::reverse(mask_tensor_list.begin(), mask_tensor_list.end()); }*/ std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end()); std::reverse(output_grad_tensor_unbind->begin(), output_grad_tensor_unbind->end()); } Tensor* weight_grad = &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]); weight_grad->mutable_data(context.GetPlace()); zero(device_ctx, weight_grad, static_cast(0.0)); Tensor* pre_hidden = nullptr; Tensor* pre_state = nullptr; Tensor* hidden = nullptr; if (is_gru(context)) { zero(device_ctx, &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]), static_cast(0.0)); } for (int i = time_step - 1; i >= 0; --i) { if (has_sequence_length) { this->mask_preprocess(context, &(*output_grad_tensor_unbind)[i], dynamic_grad_last_h, dynamic_grad_last_c, dynamic_grad_pre_h, dynamic_grad_pre_c, mask_tensor_list[i]); } else { this->preprocess(context, &(*output_grad_tensor_unbind)[i], dynamic_grad_last_h); } hidden = &(*output_tensor_unbind)[i]; if (i == 0) { pre_hidden = &(*init_h_unbind)[current_layer_idx]; if (init_c_unbind->size() > 0) { pre_state = &(*init_c_unbind)[current_layer_idx]; } } else { pre_hidden = &(*output_tensor_unbind)[i - 1]; if (layer_state_tensor_unbind->size() > 0) { pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1]; } } this->cell_( context, &(*layer_gate_tensor_unbind)[i], &(*layer_state_tensor_unbind)[begin_idx + i], &(*layer_act_state_tensor_unbind)[begin_idx + i], hidden, &(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]), pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c, &(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h, dynamic_grad_pre_c, &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]), mask_tensor_list[i], has_sequence_length); SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h); SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c); } // postproces for gradient for w_hi, X, bias_hi, bias_hh this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad, parameter_lists[layer_idx], &((*weight_list_grad)[layer_idx]), is_reverse); // copy the gradient to init_c init_h if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) { framework::TensorCopy(*dynamic_grad_last_h, context.GetPlace(), &((*init_h_grad_unbind)[current_layer_idx])); } if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) { framework::TensorCopy(*dynamic_grad_last_c, context.GetPlace(), &((*init_c_grad_unbind)[current_layer_idx])); } } virtual void operator()( const framework::ExecutionContext& context, const Tensor* input, const Tensor* output, const TensorList& init_h_unbind, const TensorList& init_c_unbind, const TensorList& last_h_grad_unbind, const TensorList& last_c_grad_unbind, const TensorList& gate_tensor_unbind, const TensorList& state_tensor_unbind, const TensorList& act_state_tensor_unbind, const Tensor* output_grad, const std::vector& parameter_lists, const Tensor* sequence_length, Tensor* input_grad, TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind, const std::vector& weight_list_grad, const int& layer_idx, const int& gate_num) {} void preprocess(const framework::ExecutionContext& context, const Tensor* grad_output, Tensor* grad_last_h) { auto& place = *context.template device_context() .eigen_device(); auto output_grad = framework::EigenMatrix::Reshape( *grad_output, grad_output->dims().size() - 1); auto last_h_grad = framework::EigenMatrix::Reshape( *grad_last_h, grad_last_h->dims().size() - 1); // the output gradient contribute the gradient to last_h last_h_grad.device(place) = last_h_grad + output_grad; } void mask_preprocess(const framework::ExecutionContext& context, const Tensor* grad_output, Tensor* grad_last_h, Tensor* grad_last_c, Tensor* grad_pre_h, Tensor* grad_pre_c, const Tensor& mask_tensor) { auto& place = *context.template device_context() .eigen_device(); auto mask = framework::EigenMatrix::From( mask_tensor, pten::make_ddim({mask_tensor.dims()[1], 1})); auto mask_broadcast = mask.broadcast(Eigen::DSizes(1, grad_output->dims()[2])); auto last_h_grad = framework::EigenMatrix::Reshape( *grad_last_h, grad_last_h->dims().size() - 1); auto pre_h_grad = framework::EigenMatrix::Reshape( *grad_pre_h, grad_pre_h->dims().size() - 1); auto output_grad = framework::EigenMatrix::Reshape( *grad_output, grad_output->dims().size() - 1); last_h_grad.device(place) = last_h_grad + output_grad * mask_broadcast; pre_h_grad.device(place) = (1 - mask_broadcast) * last_h_grad; last_h_grad.device(place) = mask_broadcast * last_h_grad; if (grad_last_c && grad_pre_c && is_lstm(context)) { auto last_c_grad = framework::EigenMatrix::Reshape( *grad_last_c, grad_last_c->dims().size() - 1); auto pre_c_grad = framework::EigenMatrix::Reshape( *grad_pre_c, grad_pre_c->dims().size() - 1); pre_c_grad.device(place) = (1 - mask_broadcast) * last_c_grad; last_c_grad.device(place) = mask_broadcast * last_c_grad; } } void postprocess(const framework::ExecutionContext& context, const Tensor& grad_gate, const Tensor& input, Tensor* input_grad, const TensorList& parameters, TensorList* grad_parameters, const int& is_reverse) { // we get the grad_gate step by step, and need to bradocast the grad to the // grad_w_hi, grad_bias_hi, grad_bias_hh int begin_idx = 0; if (is_reverse) { begin_idx = 4; } auto& device_ctx = context.template device_context(); auto blas = pten::funcs::GetBlas(device_ctx); // calc the gradient for the w_hi auto mat_dim_out_grad = pten::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, true); auto mat_dim_input = pten::funcs::CreateMatrixDescriptor(input.dims(), 0, false); mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_; mat_dim_out_grad.batch_size_ = 0; mat_dim_input.height_ *= mat_dim_input.batch_size_; mat_dim_input.batch_size_ = 0; blas.MatMul(grad_gate, mat_dim_out_grad, input, mat_dim_input, static_cast(1.0), &((*grad_parameters)[begin_idx + 0]), T(0)); // calc the gradient for the X auto mat_dim_out_grad_new = pten::funcs::CreateMatrixDescriptor(grad_gate.dims(), 0, false); mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_; mat_dim_out_grad_new.batch_size_ = 0; auto mat_dim_parameter = pten::funcs::CreateMatrixDescriptor(parameters[0].dims(), 0, false); blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0], mat_dim_parameter, static_cast(1.0), input_grad, T(1)); // calc the gradient of Bias_hi, Bias_hh pten::funcs::ColwiseSum col_sum; Tensor tmp_grad_gate; tmp_grad_gate.ShareDataWith(grad_gate); tmp_grad_gate.Resize( {grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]}); col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2])); // Bias_hh if (!is_gru(context)) { col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3])); } } GradCellType cell_; }; template struct SingleGradLayer : GradLayer { // explicit SingleGradLayer(GradCellType& cell) : cell_(cell) {} explicit SingleGradLayer(const GradCellType& cell) : GradLayer(cell) {} virtual ~SingleGradLayer() {} void operator()( const framework::ExecutionContext& context, const Tensor* input, const Tensor* output, std::vector* init_h_unbind, std::vector* init_c_unbind, const TensorList& last_h_grad_unbind, const TensorList& last_c_grad_unbind, const TensorList& gate_tensor_unbind, const TensorList& state_tensor_unbind, const TensorList& act_state_tensor_unbind, const Tensor* output_grad, const std::vector& parameter_lists, const Tensor* sequence_length, Tensor* input_grad, TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind, std::vector* weight_list_grad, const int& layer_idx, const int& gate_num) { auto& device_ctx = context.template device_context(); pten::funcs::SetConstant zero; zero(device_ctx, input_grad, static_cast(0.0)); const bool& is_bidirec = context.Attr("is_bidirec"); const int& time_step = input->dims()[0]; const int& batch_size = input->dims()[1]; const int& direction_num = is_bidirec ? 2 : 1; const int& hidden_size = context.Attr("hidden_size"); // in this section, create the gate_state_grad for the postprocess calculate // ubind the output, the output from [time_step, batch_size, hidden_size] auto output_tensor_unbind = Unbind(*output); auto output_grad_tensor_unbind = Unbind(*output_grad); auto layer_gate_tensor = gate_tensor_unbind[layer_idx]; layer_gate_tensor.Resize( {time_step * direction_num, batch_size, hidden_size * gate_num}); auto layer_gate_tensor_unbind = Unbind(layer_gate_tensor); // the gate_tensor and the grad_gate_tensor must be unbind Tensor layer_grad_gate_tensor; layer_grad_gate_tensor.Resize(layer_gate_tensor.dims()); layer_grad_gate_tensor.mutable_data(context.GetPlace()); auto layer_grad_gate_tensor_unbind = Unbind(layer_grad_gate_tensor); Tensor layer_state_tensor; TensorList layer_state_tensor_unbind; if (state_tensor_unbind.size() > 0) { layer_state_tensor = state_tensor_unbind[layer_idx]; layer_state_tensor.Resize( {time_step * direction_num, batch_size, hidden_size}); layer_state_tensor_unbind = Unbind(layer_state_tensor); } Tensor layer_act_state_tensor; TensorList layer_act_state_tensor_unbind; if (act_state_tensor_unbind.size() > 0) { layer_act_state_tensor = act_state_tensor_unbind[layer_idx]; layer_act_state_tensor.Resize( {time_step * direction_num, batch_size, hidden_size}); layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor); } const bool& has_sequence_length = sequence_length == nullptr ? false : true; this->run_rnn_grad_function( context, device_ctx, input, input_grad, sequence_length, init_h_unbind, init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &layer_grad_gate_tensor, &layer_gate_tensor_unbind, &layer_grad_gate_tensor_unbind, &layer_state_tensor_unbind, &layer_act_state_tensor_unbind, &output_tensor_unbind, &output_grad_tensor_unbind, last_h_grad_unbind, last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx, time_step, has_sequence_length, is_bidirec, false); } }; template void split_tensor_at_last_dim(const framework::ExecutionContext& context, const platform::CPUDeviceContext& dev_ctx, const Tensor* output, std::vector* output_vec, const int& axis) { std::vector shape_refer; (*output_vec)[0]->Resize( {output->dims()[0], output->dims()[1], output->dims()[2] / 2}); (*output_vec)[0]->mutable_data(context.GetPlace()); (*output_vec)[1]->Resize( {output->dims()[0], output->dims()[1], output->dims()[2] / 2}); (*output_vec)[1]->mutable_data(context.GetPlace()); shape_refer.emplace_back((*output_vec)[0]); shape_refer.emplace_back((*output_vec)[1]); math::SplitFunctor functor; functor(dev_ctx, *output, shape_refer, axis, output_vec); } template struct BidirGradLayer : GradLayer { explicit BidirGradLayer(const GradCellType& cell) : GradLayer(cell) {} virtual ~BidirGradLayer() {} void operator()( const framework::ExecutionContext& context, const Tensor* input, const Tensor* output, std::vector* init_h_unbind, std::vector* init_c_unbind, const TensorList& last_h_grad_unbind, const TensorList& last_c_grad_unbind, const TensorList& gate_tensor_unbind, const TensorList& state_tensor_unbind, const TensorList& act_state_tensor_unbind, const Tensor* output_grad, const std::vector& parameter_lists, const Tensor* sequence_length, Tensor* input_grad, TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind, std::vector* weight_list_grad, const int& layer_idx, const int& gate_num) { const bool& is_bidirec = context.Attr("is_bidirec"); const int& time_step = input->dims()[0]; const int& batch_size = input->dims()[1]; const int& direction_num = is_bidirec ? 2 : 1; const int& hidden_size = context.Attr("hidden_size"); // split the output two tensor to output_forward, output_backward auto& device_ctx = context.template device_context(); pten::funcs::SetConstant zero; zero(device_ctx, input_grad, static_cast(0.0)); std::vector output_vec; Tensor forward_output; Tensor backward_output; std::vector forward_output_tensor_unbind; std::vector backward_output_tensor_unbind; // in the last layer, we will use the output as the last hidden // the output just the concat the forward hidden, backward hidden, so just // split it // in other layer, we just split the hidden in the rows output_vec.emplace_back(&forward_output); output_vec.emplace_back(&backward_output); split_tensor_at_last_dim(context, device_ctx, output, &output_vec, 2); forward_output_tensor_unbind = Unbind(*(output_vec[0])); backward_output_tensor_unbind = Unbind(*(output_vec[1])); std::vector output_grad_vec; Tensor grad_forward_output; Tensor grad_backward_output; output_grad_vec.emplace_back(&grad_forward_output); output_grad_vec.emplace_back(&grad_backward_output); split_tensor_at_last_dim(context, device_ctx, output_grad, &output_grad_vec, 2); auto forward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[0])); auto backward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[1])); // the gate_tensor and the grad_gate_tensor must be unbind auto layer_gate_tensor = gate_tensor_unbind[layer_idx]; layer_gate_tensor.Resize( {time_step * 2, batch_size, hidden_size * gate_num}); auto layer_forward_gate_tensor = layer_gate_tensor.Slice(0, time_step); auto layer_backward_gate_tensor = layer_gate_tensor.Slice(time_step, 2 * time_step); auto layer_forward_gate_tensor_unbind = Unbind(layer_forward_gate_tensor); auto layer_backward_gate_tensor_unbind = Unbind(layer_backward_gate_tensor); Tensor layer_grad_gate_tensor; layer_grad_gate_tensor.Resize(layer_gate_tensor.dims()); layer_grad_gate_tensor.mutable_data(context.GetPlace()); zero(device_ctx, &layer_grad_gate_tensor, static_cast(0.0)); auto layer_forward_grad_gate_tensor = layer_grad_gate_tensor.Slice(0, time_step); auto layer_backward_grad_gate_tensor = layer_grad_gate_tensor.Slice(time_step, 2 * time_step); auto layer_forward_grad_gate_tensor_unbind = Unbind(layer_forward_grad_gate_tensor); auto layer_backward_grad_gate_tensor_unbind = Unbind(layer_backward_grad_gate_tensor); Tensor layer_state_tensor; TensorList layer_state_tensor_unbind; if (state_tensor_unbind.size() > 0) { layer_state_tensor = state_tensor_unbind[layer_idx]; layer_state_tensor.Resize( {time_step * direction_num, batch_size, hidden_size}); layer_state_tensor_unbind = Unbind(layer_state_tensor); } Tensor layer_act_state_tensor; TensorList layer_act_state_tensor_unbind; if (act_state_tensor_unbind.size() > 0) { layer_act_state_tensor = act_state_tensor_unbind[layer_idx]; layer_act_state_tensor.Resize( {time_step * direction_num, batch_size, hidden_size}); layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor); } const bool& has_sequence_length = sequence_length == nullptr ? false : true; this->run_rnn_grad_function( context, device_ctx, input, input_grad, sequence_length, init_h_unbind, init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &layer_forward_grad_gate_tensor, &layer_forward_gate_tensor_unbind, &layer_forward_grad_gate_tensor_unbind, &layer_state_tensor_unbind, &layer_act_state_tensor_unbind, &forward_output_tensor_unbind, &forward_output_grad_tensor_unbind, last_h_grad_unbind, last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx, time_step, has_sequence_length, is_bidirec, false); this->run_rnn_grad_function( context, device_ctx, input, input_grad, sequence_length, init_h_unbind, init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, &layer_backward_grad_gate_tensor, &layer_backward_gate_tensor_unbind, &layer_backward_grad_gate_tensor_unbind, &layer_state_tensor_unbind, &layer_act_state_tensor_unbind, &backward_output_tensor_unbind, &backward_output_grad_tensor_unbind, last_h_grad_unbind, last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx, time_step, has_sequence_length, is_bidirec, true); } }; template void backup_tensor(const framework::ExecutionContext& context, Tensor* dst, Tensor* src) { auto& device_ctx = context.template device_context(); dst->Resize(src->dims()); dst->mutable_data(context.GetPlace()); framework::TensorCopy(*src, device_ctx.GetPlace(), device_ctx, dst); } template struct GradCell { virtual ~GradCell() {} virtual void operator()(const framework::ExecutionContext& context, Tensor* gate_tensor, Tensor* state_tensor, Tensor* act_state_tensor, Tensor* hidden_tensor, const Tensor* weight_hh, Tensor* pre_hidden, Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, Tensor* grad_gate, Tensor* grad_weight_hh, Tensor* grad_pre_hidden, Tensor* grad_pre_state, Tensor* grad_bias_hh, const Tensor& mask_tensor, bool has_sequence_length) const {} void postprocess_pre_hidden_grad(const framework::ExecutionContext& context, Tensor* grad_pre_hidden, Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state, Tensor* grad_pre_state_bak, const Tensor& mask_tensor, bool has_sequence_length) const { if (has_sequence_length) { auto& place = *context.template device_context() .eigen_device(); auto mask = framework::EigenMatrix::From( mask_tensor, pten::make_ddim({mask_tensor.dims()[1], 1})); auto mask_broadcast = mask.broadcast(Eigen::DSizes(1, grad_pre_hidden->dims()[2])); auto pre_hidden_grad = framework::EigenMatrix::Reshape( *grad_pre_hidden, grad_pre_hidden->dims().size() - 1); auto pre_hidden_bak_grad = framework::EigenMatrix::Reshape( *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1); pre_hidden_grad.device(place) = (1 - mask_broadcast) * pre_hidden_bak_grad + pre_hidden_grad * mask_broadcast; if (grad_pre_state) { auto pre_state_grad = framework::EigenMatrix::Reshape( *grad_pre_state, grad_pre_state->dims().size() - 1); auto pre_state_bak_grad = framework::EigenMatrix::Reshape( *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1); pre_state_grad.device(place) = (1 - mask_broadcast) * pre_state_bak_grad + pre_state_grad * mask_broadcast; } } } virtual void update_pre_hidden_grad( const framework::ExecutionContext& context, Tensor* grad_gate, const Tensor* weight_hh, Tensor* grad_pre_hidden, Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state, Tensor* grad_pre_state_bak, const Tensor& mask_tensor, bool has_sequence_length) const { auto& device_ctx = context.template device_context(); auto blas = pten::funcs::GetBlas(device_ctx); Tensor* grad_gate_tmp = grad_gate; auto mat_dim_a = pten::funcs::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false); mat_dim_a.height_ *= mat_dim_a.batch_size_; mat_dim_a.batch_size_ = 0; auto mat_dim_b = pten::funcs::CreateMatrixDescriptor(weight_hh->dims(), 0, false); blas.MatMul(*grad_gate_tmp, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), grad_pre_hidden, 0); postprocess_pre_hidden_grad(context, grad_pre_hidden, grad_pre_hidden_bak, grad_pre_state, grad_pre_state_bak, mask_tensor, has_sequence_length); } virtual void update_weight_hh_grad(const framework::ExecutionContext& context, Tensor* grad_gate, Tensor* pre_hidden, Tensor* grad_weight_hh) const { auto& device_ctx = context.template device_context(); auto blas = pten::funcs::GetBlas(device_ctx); auto mat_dim_c = pten::funcs::CreateMatrixDescriptor(grad_gate->dims(), 0, true); mat_dim_c.height_ *= mat_dim_c.batch_size_; mat_dim_c.batch_size_ = 0; auto mat_dim_d = pten::funcs::CreateMatrixDescriptor(pre_hidden->dims(), 0, false); mat_dim_d.height_ *= mat_dim_d.batch_size_; mat_dim_d.batch_size_ = 0; blas.MatMul(*grad_gate, mat_dim_c, *pre_hidden, mat_dim_d, static_cast(1.0), grad_weight_hh, static_cast(1.0)); } }; template class EigenActivationBackwardFunctor> struct SimpleRNNGradCell : GradCell { void operator()(const framework::ExecutionContext& context, Tensor* gate_tensor, Tensor* state_tensor, Tensor* act_state_tensor, Tensor* hidden_tensor, const Tensor* weight_hh, Tensor* pre_hidden, Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, Tensor* grad_gate, Tensor* grad_weight_hh, Tensor* grad_pre_hidden, Tensor* grad_pre_state, Tensor* grad_bias_hh, const Tensor& mask_tensor, bool has_sequence_length) const override { auto& device_ctx = context.template device_context(); Tensor grad_pre_hidden_bak; if (has_sequence_length) { backup_tensor(context, &grad_pre_hidden_bak, grad_pre_hidden); } // h = act(z) // update dz auto dz = EigenVector::Flatten( GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad")); auto dh = EigenVector::Flatten( GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad")); auto h = EigenVector::Flatten( GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value")); // useless, but need this argument to execute functor auto z = EigenVector::Flatten( GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value")); auto* place = device_ctx.eigen_device(); EigenActivationBackwardFunctor functor; functor(*place, z, h, dh, dz); // update grad_weight_hh, grad_pre_hidden this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak, nullptr, nullptr, mask_tensor, has_sequence_length); this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh); } }; template struct GRUGradCell : GradCell { void operator()(const framework::ExecutionContext& context, Tensor* gate_tensor, Tensor* state_tensor, Tensor* act_state_tensor, Tensor* hidden_tensor, const Tensor* weight_hh, Tensor* pre_hidden, Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, Tensor* grad_gate, Tensor* grad_weight_hh, Tensor* grad_pre_hidden, Tensor* grad_pre_state, Tensor* grad_bias_hh, const Tensor& mask_tensor, bool has_sequence_length) const override { auto& device_ctx = context.template device_context(); size_t frame_size = pre_hidden->dims()[2]; size_t batch_size = pre_hidden->dims()[1]; Tensor grad_pre_hidden_bak; if (has_sequence_length) { backup_tensor(context, &grad_pre_hidden_bak, grad_pre_hidden); } // zero pre_hidden pten::funcs::SetConstant zero; zero(device_ctx, grad_pre_hidden, static_cast(0.0)); math::GRUMetaValue gru_value; math::GRUMetaGrad gru_grad; gru_value.gate_value = gate_tensor->data(); gru_value.prev_out_value = pre_hidden->data(); gru_value.reset_output_value = state_tensor->data(); gru_value.state_weight = weight_hh->data() + 2 * frame_size * frame_size; gru_value.gate_weight = weight_hh->data(); gru_grad.gate_grad = grad_gate->data(); gru_grad.reset_output_grad = grad_state->data(); gru_grad.prev_out_grad = grad_pre_hidden->data(); gru_grad.output_grad = grad_hidden->data(); gru_grad.gate_weight_grad = grad_weight_hh->data(); gru_grad.state_weight_grad = grad_weight_hh->data() + 2 * frame_size * frame_size; gru_grad.bias_hh_grad = grad_bias_hh->data(); auto act_gate = math::detail::GetActivationType("sigmoid_v2"); auto act_node = math::detail::GetActivationType("tanh_v2"); math::GRUUnitGradFunctorV2::compute( device_ctx, gru_value, gru_grad, frame_size, batch_size, act_node, act_gate); this->postprocess_pre_hidden_grad(context, grad_pre_hidden, &grad_pre_hidden_bak, nullptr, nullptr, mask_tensor, has_sequence_length); } }; template struct LSTMGradCell : GradCell { void operator()(const framework::ExecutionContext& context, Tensor* gate_tensor, Tensor* state_tensor, Tensor* act_state_tensor, Tensor* hidden_tensor, const Tensor* weight_hh, Tensor* pre_hidden, Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, Tensor* grad_gate, Tensor* grad_weight_hh, Tensor* grad_pre_hidden, Tensor* grad_pre_state, Tensor* grad_bias_hh, const Tensor& mask_tensor, bool has_sequence_length) const override { auto& device_ctx = context.template device_context(); size_t frame_size = state_tensor->dims()[2]; size_t batch_size = state_tensor->dims()[1]; Tensor grad_pre_hidden_bak; Tensor grad_pre_state_bak; if (has_sequence_length) { backup_tensor(context, &grad_pre_hidden_bak, grad_pre_hidden); backup_tensor(context, &grad_pre_state_bak, grad_pre_state); } math::LstmMetaValue lstm_value; math::LstmMetaGrad lstm_grad; create_lstm_value(&lstm_value); create_lstm_grad(&lstm_grad); lstm_value.gate_value = gate_tensor->data(); lstm_value.state_value = state_tensor->data(); lstm_value.state_active_value = act_state_tensor->data(); lstm_value.prev_state_value = pre_state->data(); lstm_grad.state_grad = grad_state->data(); lstm_grad.gate_grad = grad_gate->data(); lstm_grad.output_grad = grad_hidden->data(); lstm_grad.prev_state_grad = grad_pre_state->data(); lstm_value.output_value = nullptr; lstm_grad.state_active_grad = nullptr; auto gate_act = math::detail::GetActivationType("sigmoid_v2"); auto state_act = math::detail::GetActivationType("tanh_v2"); auto cand_act = math::detail::GetActivationType("tanh_v2"); T cell_clip = 0.0; math::LstmUnitGradFunctor::compute( device_ctx, lstm_value, lstm_grad, frame_size, batch_size, cell_clip, gate_act, state_act, cand_act, false); this->update_pre_hidden_grad( context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak, grad_pre_state, &grad_pre_state_bak, mask_tensor, has_sequence_length); this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh); } }; template class SingleGradLayerT, template class BidirGradLayerT, typename T> void RnnGradFunc(const framework::ExecutionContext& context, const int& gate_num) { // get the tensor pointer for the input auto* input = context.Input("Input"); auto weight_list = context.MultiInput("WeightList"); auto pre_state = context.MultiInput("PreState"); const Tensor* init_h = pre_state[0]; const Tensor* init_c = nullptr; if (is_lstm(context)) { init_c = pre_state[1]; } auto* reserve_state = context.Input("Reserve"); auto* dropout_state = context.Input("DropoutState"); auto* output = context.Input("Out"); auto* output_grad = context.Input(framework::GradVarName("Out")); auto state_grad = context.MultiInput(framework::GradVarName("State")); const Tensor* last_h_grad = state_grad[0]; const Tensor* last_c_grad = nullptr; if (is_lstm(context)) { last_c_grad = state_grad[1]; } bool has_seq_length = context.HasInput("SequenceLength"); const Tensor* sequence_length = nullptr; if (has_seq_length) { sequence_length = context.Input("SequenceLength"); } // get the tensor pointer for the output auto* input_grad = context.Output(framework::GradVarName("Input")); auto weight_grad_list = context.MultiOutput( framework::GradVarName("WeightList")); auto pre_state_grad = context.MultiOutput(framework::GradVarName("PreState")); Tensor* init_h_grad = nullptr; Tensor* init_c_grad = nullptr; if (pre_state_grad.size() > 0) { // has gradient init_h_grad = pre_state_grad[0]; if (is_lstm(context)) { init_c_grad = pre_state_grad[1]; } } // get the attributes for the calcluate const int& num_layers = context.Attr("num_layers"); const bool& is_bidirec = context.Attr("is_bidirec"); const float& dropout_prob = context.Attr("dropout_prob"); bool is_test = context.HasAttr("is_test") ? context.Attr("is_test") : false; // get the input_size, batch_size, time_step, hidden_size const int& time_step = input->dims()[0]; const int& batch_size = input->dims()[1]; const int& hidden_size = context.Attr("hidden_size"); const int& direction_num = is_bidirec ? 2 : 1; // allocate the memory and initization the input_grad Tensor input_grad_value; if (!input_grad) { input_grad = &input_grad_value; } input_grad->mutable_data(input->dims(), context.GetPlace()); if (init_h_grad) { init_h_grad->mutable_data(init_h->dims(), context.GetPlace()); } if (init_c_grad) { init_c_grad->mutable_data(init_c->dims(), context.GetPlace()); } // reset the parameter to sorted order and allocate the memory std::vector parameter_lists; parameter_lists.reserve(num_layers); reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec, ¶meter_lists); for (unsigned int i = 0; i < weight_grad_list.size(); ++i) { weight_grad_list[i]->mutable_data(context.GetPlace()); } std::vector parameter_lists_grad; parameter_lists_grad.reserve(num_layers); reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec, ¶meter_lists_grad); // resolve the state of reverse_state Tensor gate_tensor; Tensor state_tensor; Tensor act_state_tensor; Tensor hidden_tensor; SplitReserveData(context, reserve_state, &gate_tensor, &state_tensor, &act_state_tensor, &hidden_tensor, direction_num, time_step, batch_size, hidden_size, gate_num, num_layers); int gate_num_tmp = gate_num; if (gate_num == 0) { gate_num_tmp = 1; } gate_tensor.Resize({num_layers, time_step * direction_num, batch_size, hidden_size * gate_num_tmp}); if (state_tensor.numel() > 0) { state_tensor.Resize( {num_layers, time_step * direction_num, batch_size, hidden_size}); } if (act_state_tensor.numel() > 0) { act_state_tensor.Resize( {num_layers, time_step * direction_num, batch_size, hidden_size}); } if (num_layers > 1) { hidden_tensor.Resize( {num_layers - 1, time_step, batch_size, hidden_size * direction_num}); } // unbind auto last_h_grad_unbind = Unbind(*last_h_grad); auto gate_tensor_unbind = Unbind(gate_tensor); TensorList last_c_grad_unbind; if (last_c_grad) { last_c_grad_unbind = Unbind(*last_c_grad); } TensorList init_h_unbind, init_c_unbind; TensorList init_h_grad_unbind, init_c_grad_unbind; TensorList state_tensor_unbind, act_state_tensor_unbind; TensorList hidden_tensor_unbind; init_h_unbind = Unbind(*init_h); if (init_c) { init_c_unbind = Unbind(*init_c); } if (init_h_grad != nullptr) { init_h_grad_unbind = Unbind(*init_h_grad); } if (init_c_grad != nullptr) { init_c_grad_unbind = Unbind(*init_c_grad); } if (state_tensor.numel() > 0) { state_tensor_unbind = Unbind(state_tensor); } if (act_state_tensor.numel() > 0) { act_state_tensor_unbind = Unbind(act_state_tensor); } if (num_layers > 1) { hidden_tensor_unbind = Unbind(hidden_tensor); } // squeeze the hidden first dim for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) { hidden_tensor_unbind[i].Resize( pten::slice_ddim(hidden_tensor_unbind[i].dims(), 1, hidden_tensor_unbind[i].dims().size())); } // add the output tensor to the hidden vector Tensor tmp; hidden_tensor_unbind.emplace_back(tmp); hidden_tensor_unbind[num_layers - 1].ShareDataWith(*output); GradCellType cell; Tensor layer_input; Tensor layer_output; Tensor* layer_input_grad_holder = nullptr; Tensor tmp_out; tmp_out.ShareDataWith(*output_grad); Tensor* layer_output_grad_holder = &tmp_out; Tensor input_grad_temp; Tensor output_grad_temp; bool has_allocate_mem = false; for (int i = num_layers - 1; i >= 0; --i) { // the layer input output had saved, just use the data if (i > 0) { if (layer_input.numel() == 0) { layer_input.Resize(hidden_tensor_unbind[i - 1].dims()); layer_input.mutable_data(context.GetPlace()); } dropout_helper(context, &hidden_tensor_unbind[i - 1], &layer_input, dropout_state, dropout_prob); } else { layer_input.ShareDataWith(*input); } layer_output.ShareDataWith(hidden_tensor_unbind[i]); if (num_layers == 1) { layer_input_grad_holder = input_grad; } else { if (i == num_layers - 1) { input_grad_temp.Resize(layer_input.dims()); input_grad_temp.mutable_data(context.GetPlace()); layer_input_grad_holder = &input_grad_temp; } } if (is_bidirec) { BidirGradLayerT layer(cell); layer(context, &layer_input, &layer_output, &init_h_unbind, &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind, gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind, layer_output_grad_holder, parameter_lists, sequence_length, layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind, ¶meter_lists_grad, i, gate_num_tmp); } else { SingleGradLayerT layer(cell); layer(context, &layer_input, &layer_output, &init_h_unbind, &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind, gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind, layer_output_grad_holder, parameter_lists, sequence_length, layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind, ¶meter_lists_grad, i, gate_num_tmp); } // calcluate the dropout gradient for the layer_input_grad_holder // dropout_state save in the forward process if (i > 0) { if ((!is_test) && (dropout_prob != 0)) { dropout_cpu_grad_function_inplace(context, layer_input_grad_holder, dropout_state, dropout_prob); } } if (i - 1 == 0) { layer_output_grad_holder = input_grad; } else { if (!has_allocate_mem) { output_grad_temp.Resize(layer_input_grad_holder->dims()); output_grad_temp.mutable_data(context.GetPlace()); layer_output_grad_holder = &output_grad_temp; has_allocate_mem = true; } } SwapPoniter(&layer_input_grad_holder, &layer_output_grad_holder); } } template class RNNCPUGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { int gate_num = 4; if (is_lstm(ctx)) { RnnGradFunc, SingleGradLayer, BidirGradLayer, T>( ctx, gate_num); } else if (is_gru(ctx)) { gate_num = 3; RnnGradFunc, SingleGradLayer, BidirGradLayer, T>(ctx, gate_num); // run gru } else if (is_rnn_relu(ctx)) { gate_num = 1; RnnGradFunc, SingleGradLayer, BidirGradLayer, T>(ctx, gate_num); // run rnn } else if (is_rnn_tanh(ctx)) { gate_num = 1; RnnGradFunc, SingleGradLayer, BidirGradLayer, T>(ctx, gate_num); } } }; } // namespace operators } // namespace paddle