From 9362d85e0ef9afb0fcd36e12d0a4eac92f08265f Mon Sep 17 00:00:00 2001 From: Jack Zhou Date: Fri, 20 Nov 2020 12:58:42 +0800 Subject: [PATCH] Add LSTM, Simple RNN and GRU CPU kernel (#28577) * add lstm, simple rnn op kernel * fix the test_lstm for the rnn op * change func name * fix forward postprocess bug * add gru forward, backward code * remove unittest.skipIf; use a big rnn op instead of combination op * fix input doesn't have gradient bug * add eigen lstm forward, backward Co-authored-by: wawltor --- .../math/detail/activation_functions.h | 53 +- .../operators/math/detail/avx_functions.cc | 15 + .../operators/math/detail/gru_cpu_kernel.h | 282 ++- .../operators/math/detail/gru_gpu_kernel.h | 31 +- .../fluid/operators/math/detail/gru_kernel.h | 79 +- .../operators/math/detail/lstm_cpu_kernel.h | 228 +- paddle/fluid/operators/math/gru_compute.cc | 54 + paddle/fluid/operators/math/gru_compute.h | 24 +- paddle/fluid/operators/math/lstm_compute.cc | 17 +- paddle/fluid/operators/math/lstm_compute.cu | 6 +- paddle/fluid/operators/math/lstm_compute.h | 8 +- paddle/fluid/operators/rnn_op.cc | 10 +- paddle/fluid/operators/rnn_op.cu.cc | 6 + paddle/fluid/operators/rnn_op.h | 2085 +++++++++++++++++ .../unittests/dygraph_to_static/test_lstm.py | 26 +- .../fluid/tests/unittests/rnn/convert.py | 31 + .../fluid/tests/unittests/rnn/rnn_numpy.py | 103 +- .../fluid/tests/unittests/test_gru_rnn_op.py | 164 ++ .../fluid/tests/unittests/test_rnn_op.py | 159 ++ .../tests/unittests/test_simple_rnn_op.py | 162 ++ .../white_list/check_shape_white_list.py | 1 + .../white_list/no_check_set_white_list.py | 1 + .../white_list/op_threshold_white_list.py | 3 +- python/paddle/nn/layer/rnn.py | 3 +- 24 files changed, 3376 insertions(+), 175 deletions(-) create mode 100644 paddle/fluid/operators/rnn_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_gru_rnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_rnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_simple_rnn_op.py diff --git a/paddle/fluid/operators/math/detail/activation_functions.h b/paddle/fluid/operators/math/detail/activation_functions.h index 5476b1a2d3..883ddec8fa 100644 --- a/paddle/fluid/operators/math/detail/activation_functions.h +++ b/paddle/fluid/operators/math/detail/activation_functions.h @@ -30,18 +30,24 @@ namespace detail { enum ActivationType { kSigmoid, + KSigmoidV2, kReLU, kTanh, + kTanhV2, kIdentity, }; inline ActivationType GetActivationType(const std::string &type) { if (type == "sigmoid") { return ActivationType::kSigmoid; + } else if (type == "sigmoid_v2") { + return ActivationType::KSigmoidV2; } else if (type == "relu") { return ActivationType::kReLU; } else if (type == "tanh") { return ActivationType::kTanh; + } else if (type == "tanh_v2") { + return ActivationType::kTanhV2; } else if (type == "identity" || type == "") { return ActivationType::kIdentity; } @@ -68,6 +74,14 @@ DEVICE T Sigmoid(const T a) { return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); } +/* + * Don't limit input in a threshold range. + */ +template +DEVICE T SigmoidV2(const T a) { + return static_cast(1.0) / (static_cast(1.0) + exp(-a)); +} + template DEVICE T Tanh(const T a) { T tmp = -2.0 * a; @@ -75,6 +89,15 @@ DEVICE T Tanh(const T a) { return (2.0 / (1.0 + exp(tmp))) - 1.0; } +/* + * Don't limit input in a threshold range. 
+ */ +template +DEVICE T TanhV2(const T a) { + T tmp = -2.0 * a; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + } // namespace forward namespace backward { @@ -108,20 +131,24 @@ struct Active { }; static DEVICE Active::Act kActFloat[] = { - &forward::Sigmoid, &forward::Relu, &forward::Tanh, - &forward::Identity}; + &forward::Sigmoid, &forward::SigmoidV2, + &forward::Relu, &forward::Tanh, + &forward::TanhV2, &forward::Identity}; static DEVICE Active::ActGrad kActGradFloat[] = { - &backward::Sigmoid, &backward::Relu, &backward::Tanh, - &backward::Identity}; + &backward::Sigmoid, &backward::Sigmoid, + &backward::Relu, &backward::Tanh, + &backward::Tanh, &backward::Identity}; static DEVICE Active::Act kActDouble[] = { - &forward::Sigmoid, &forward::Relu, &forward::Tanh, - &forward::Identity}; + &forward::Sigmoid, &forward::SigmoidV2, + &forward::Relu, &forward::Tanh, + &forward::TanhV2, &forward::Identity}; static DEVICE Active::ActGrad kActGradDouble[] = { - &backward::Sigmoid, &backward::Relu, - &backward::Tanh, &backward::Identity}; + &backward::Sigmoid, &backward::Sigmoid, + &backward::Relu, &backward::Tanh, + &backward::Tanh, &backward::Identity}; namespace forward { inline DEVICE float activation(float a, int index) { @@ -149,7 +176,9 @@ namespace forward { namespace avx { __m256 Relu(const __m256 a); __m256 Sigmoid(const __m256 a); +__m256 SigmoidV2(const __m256 a); __m256 Tanh(const __m256 a); +__m256 TanhV2(const __m256 a); __m256 Identity(const __m256 a); } // namespace avx } // namespace forward @@ -164,12 +193,12 @@ __m256 Identity(const __m256 a, const __m256 b); } // namespace backward static Active<__m256>::Act kActAvx[] = { - &forward::avx::Sigmoid, &forward::avx::Relu, &forward::avx::Tanh, - &forward::avx::Identity}; + &forward::avx::Sigmoid, &forward::avx::SigmoidV2, &forward::avx::Relu, + &forward::avx::Tanh, &forward::avx::TanhV2, &forward::avx::Identity}; static Active<__m256>::ActGrad kActGradAvx[] = { - &backward::avx::Sigmoid, &backward::avx::Relu, &backward::avx::Tanh, - &backward::avx::Identity}; + &backward::avx::Sigmoid, &backward::avx::Sigmoid, &backward::avx::Relu, + &backward::avx::Tanh, &backward::avx::Tanh, &backward::avx::Identity}; namespace forward { inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); } diff --git a/paddle/fluid/operators/math/detail/avx_functions.cc b/paddle/fluid/operators/math/detail/avx_functions.cc index 022ffc5337..89e2c825c2 100644 --- a/paddle/fluid/operators/math/detail/avx_functions.cc +++ b/paddle/fluid/operators/math/detail/avx_functions.cc @@ -43,6 +43,13 @@ __m256 Sigmoid(const __m256 a) { return tmp; } +__m256 SigmoidV2(const __m256 a) { + __m256 tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), a); + tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), exp256_ps(tmp)); + tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); + return tmp; +} + __m256 Tanh(const __m256 a) { __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); @@ -53,6 +60,14 @@ __m256 Tanh(const __m256 a) { _mm256_set1_ps(1.0f)); } +__m256 TanhV2(const __m256 a) { + __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); + return _mm256_sub_ps( + _mm256_div_ps(_mm256_set1_ps(2.0f), + _mm256_add_ps(_mm256_set1_ps(1.0f), exp256_ps(tmp))), + _mm256_set1_ps(1.0f)); +} + __m256 Identity(const __m256 a) { return a; } } // namespace avx diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h index c6dd972e12..e05a5190e8 100644 --- 
a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h @@ -25,26 +25,38 @@ namespace detail { #ifndef __NVCC__ template -void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, - T *gate_value, T *reset_output_value, - T *prev_output_value, int frame_size, - ActivationType active_gate) { +void hl_naive_gru_forward_reset_output( + OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, + const T *prev_output_value, int frame_size, ActivationType active_gate, + bool old_version = true, const T *reset_bias = nullptr) { T r_value_update_gate; T r_value_reset_gate; T r_value_reset_output; T r_prev_out = 0; - T *update_gate = gate_value; - T *reset_gate = gate_value + frame_size; - + T r_reset_bias = 0; + T *update_gate = nullptr; + T *reset_gate = nullptr; + if (old_version) { + update_gate = gate_value; + reset_gate = gate_value + frame_size; + } else { + reset_gate = gate_value; + update_gate = gate_value + frame_size; + } for (int i = 0; i < frame_size; i++) { r_value_update_gate = update_gate[i]; r_value_reset_gate = reset_gate[i]; + if (!old_version) { + r_value_reset_output = reset_output_value[i]; + r_reset_bias = reset_bias[i]; + } if (prev_output_value) { r_prev_out = prev_output_value[i]; } op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate); + &r_value_reset_output, active_gate, &r_reset_bias, + old_version); update_gate[i] = r_value_update_gate; reset_gate[i] = r_value_reset_gate; @@ -53,16 +65,20 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, } template -void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, - T *gate_value, T *prev_output_value, - T *output_value, int frame_size, - ActivationType active_node, - bool origin_mode) { +void hl_naive_gru_forward_final_output( + OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value, + T *output_value, int frame_size, ActivationType active_node, + bool origin_mode, bool old_version = true) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; T r_output; - T *update_gate = gate_value; + T *update_gate; + if (old_version) { + update_gate = gate_value; + } else { + update_gate = gate_value + frame_size; + } T *frame_state = gate_value + frame_size * 2; for (int i = 0; i < frame_size; i++) { @@ -83,16 +99,26 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, template void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, - T *prev_output_value, int frame_size, - ActivationType active_gate) { + const T *prev_output_value, int frame_size, + ActivationType active_gate, + bool old_version = true, + const T *reset_bias = nullptr) { #ifdef __AVX__ __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_reset_output; __m256 r_prev_out = _mm256_set1_ps(0.0f), r_prev_out_last = _mm256_set1_ps(0.0f); - T *update_gate = gate_value; - T *reset_gate = gate_value + frame_size; + __m256 r_reset_bias = _mm256_set1_ps(0.0f); + T *update_gate; + T *reset_gate; + if (old_version) { + update_gate = gate_value; + reset_gate = gate_value + frame_size; + } else { + reset_gate = gate_value; + update_gate = gate_value + frame_size; + } int block = 8; const int n = frame_size; const int rest = n % block; @@ -115,9 +141,15 @@ void hl_avx_gru_forward_reset_output(OpResetOutput 
op_reset_output, if (prev_output_value) { r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i)); } + if (!old_version) { + r_reset_bias = _mm256_loadu_ps((const float *)(reset_bias + i)); + r_value_reset_output = + _mm256_loadu_ps((const float *)(reset_output_value + i)); + } op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate); + &r_value_reset_output, active_gate, &r_reset_bias, + old_version); _mm256_storeu_ps(reinterpret_cast(update_gate + i), r_value_update_gate); @@ -131,7 +163,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, i = n - block; op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last, - &r_prev_out_last, &r_value_reset_output, active_gate); + &r_prev_out_last, &r_value_reset_output, active_gate, + &r_reset_bias, old_version); _mm256_storeu_ps(reinterpret_cast(update_gate + i), r_value_update_gate_last); @@ -145,17 +178,24 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, template void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, - T *gate_value, T *prev_output_value, + T *gate_value, const T *prev_output_value, T *output_value, int frame_size, ActivationType active_node, - bool origin_mode) { + bool origin_mode, + bool old_version = true) { #ifdef __AVX__ __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); __m256 r_prev_out = _mm256_set1_ps(0.0f), r_prev_out_last = _mm256_set1_ps(0.0f); __m256 r_output; - T *update_gate = gate_value; + T *update_gate; + if (old_version) { + update_gate = gate_value; + } else { + update_gate = gate_value + frame_size; + } + T *frame_state = gate_value + frame_size * 2; int block = 8; const int n = frame_size; @@ -205,19 +245,21 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, template inline void forward_reset_output(OpResetOutput op_reset_output, GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_gate) { + int batch_size, ActivationType active_gate, + bool old_version = true) { for (int b = 0; b < batch_size; b++) { if (OpResetOutput::avx && (frame_size > static_cast(8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_reset_output( op_reset_output, value.gate_value, value.reset_output_value, - value.prev_out_value, frame_size, active_gate); + value.prev_out_value, frame_size, active_gate, old_version, + value.reset_bias); } else { hl_naive_gru_forward_reset_output( op_reset_output, value.gate_value, value.reset_output_value, - value.prev_out_value, frame_size, active_gate); + value.prev_out_value, frame_size, active_gate, old_version, + value.reset_bias); } - value.gate_value += frame_size * 3; value.reset_output_value += frame_size; if (value.prev_out_value) { @@ -230,17 +272,19 @@ template inline void forward_final_output(OpFinalOutput op_final_output, GRUMetaValue value, int frame_size, int batch_size, ActivationType active_node, - bool origin_mode) { + bool origin_mode, bool old_version = true) { for (int b = 0; b < batch_size; b++) { if (OpFinalOutput::avx && (frame_size > static_cast(8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(op_final_output, value.gate_value, value.prev_out_value, value.output_value, - frame_size, active_node, origin_mode); + frame_size, active_node, origin_mode, + old_version); } else { - hl_naive_gru_forward_final_output( - op_final_output, value.gate_value, value.prev_out_value, - value.output_value, 
frame_size, active_node, origin_mode); + hl_naive_gru_forward_final_output(op_final_output, value.gate_value, + value.prev_out_value, + value.output_value, frame_size, + active_node, origin_mode, old_version); } value.gate_value += frame_size * 3; @@ -253,7 +297,7 @@ inline void forward_final_output(OpFinalOutput op_final_output, template void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, - T *gate_grad, T *prev_out_value, + T *gate_grad, const T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, ActivationType active_node, @@ -295,7 +339,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, template void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, - T *gate_grad, T *prev_out_value, + T *gate_grad, const T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, ActivationType active_gate) { @@ -340,7 +384,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, template void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, - T *gate_grad, T *prev_out_value, + T *gate_grad, const T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, ActivationType active_node, bool origin_mode) { @@ -364,7 +408,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, r_frame_state_value = frame_state_value[i]; r_out_grad = (reinterpret_cast<__m256 *>(output_grad))[i]; if (prev_out_value) { - r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i]; + r_prev_out_value = (reinterpret_cast(prev_out_value))[i]; } if (prev_out_grad) { r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; @@ -385,7 +429,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, template void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, - T *gate_grad, T *prev_out_value, + T *gate_grad, const T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, ActivationType active_gate) { @@ -412,7 +456,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, r_reset_output_grad = (reinterpret_cast<__m256 *>(reset_output_grad))[i]; } if (prev_out_value) { - r_prev_out_value = (reinterpret_cast<__m256 *>(prev_out_value))[i]; + r_prev_out_value = (reinterpret_cast(prev_out_value))[i]; } if (prev_out_grad) { r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; @@ -431,6 +475,135 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, #endif } +template +inline void hl_naive_gru_backward(OpGruGrad op_gru_grad, T *gate_value, + T *gate_grad, const T *prev_out_value, + T *prev_out_grad, T *reset_output_value, + T *reset_output_grad, T *output_grad, + int frame_size, ActivationType active_node, + ActivationType active_gate) { + T r_value_reset_gate; + T r_grad_reset_gate; + T r_value_update_gate; + T r_grad_update_gate; + T r_value_frame_state; + T r_grad_frame_state; + T r_value_prev_out = 0; + T r_grad_prev_out = 0; + T r_grad_output; + T r_value_reset_output; + T r_grad_reset_output = 0; + T *reset_gate_value = gate_value; + T *reset_gate_grad = gate_grad; + T *update_gate_value = gate_value + frame_size; + T *update_gate_grad = gate_grad + frame_size; + T *frame_state_value = gate_value + 2 * frame_size; + T *frame_state_grad = gate_grad + 2 * frame_size; + + for (int i = 0; i < frame_size; ++i) { + r_value_reset_gate = reset_gate_value[i]; + r_grad_reset_gate = 
reset_gate_grad[i]; + r_value_update_gate = update_gate_value[i]; + r_grad_update_gate = update_gate_grad[i]; + r_value_frame_state = frame_state_value[i]; + r_grad_frame_state = frame_state_grad[i]; + if (prev_out_value) { + r_value_prev_out = prev_out_value[i]; + } + if (prev_out_grad) { + r_grad_prev_out = prev_out_grad[i]; + } + r_grad_output = output_grad[i]; + r_value_reset_output = reset_output_value[i]; + if (prev_out_value && prev_out_grad) { + r_grad_reset_output = reset_output_grad[i]; + } + + op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate, + &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state, + &r_value_prev_out, &r_grad_prev_out, &r_grad_output, + &r_value_reset_output, &r_grad_reset_output, active_node, + active_gate); + + reset_gate_grad[i] = r_grad_reset_gate; + update_gate_grad[i] = r_grad_update_gate; + frame_state_grad[i] = r_grad_frame_state; + if (prev_out_grad) { + prev_out_grad[i] = r_grad_prev_out; + } + if (prev_out_value && prev_out_grad) { + reset_output_grad[i] = r_grad_reset_output; + } + } +} + +template +inline void hl_avx_gru_backward(OpGruGrad op_gru_grad, T *gate_value, + T *gate_grad, const T *prev_out_value, + T *prev_out_grad, T *reset_output_value, + T *reset_output_grad, T *output_grad, + int frame_size, ActivationType active_node, + ActivationType active_gate) { +#ifdef __AVX__ + __m256 r_value_reset_gate; + __m256 r_grad_reset_gate; + __m256 r_value_update_gate; + __m256 r_grad_update_gate; + __m256 r_value_frame_state; + __m256 r_grad_frame_state; + __m256 r_value_prev_out = _mm256_set1_ps(0.0f); + __m256 r_grad_prev_out = _mm256_set1_ps(0.0f); + __m256 r_grad_output; + __m256 r_value_reset_output; + __m256 r_grad_reset_output = _mm256_set1_ps(0.0f); + __m256 *reset_gate_value = reinterpret_cast<__m256 *>(gate_value); + __m256 *reset_gate_grad = reinterpret_cast<__m256 *>(gate_grad); + __m256 *update_gate_value = + reinterpret_cast<__m256 *>(gate_value + frame_size); + __m256 *update_gate_grad = reinterpret_cast<__m256 *>(gate_grad + frame_size); + __m256 *frame_state_value = + reinterpret_cast<__m256 *>(gate_value + 2 * frame_size); + __m256 *frame_state_grad = + reinterpret_cast<__m256 *>(gate_grad + 2 * frame_size); + + for (int i = 0; i < frame_size / 8; ++i) { + r_value_reset_gate = reset_gate_value[i]; + r_grad_reset_gate = reset_gate_grad[i]; + r_value_update_gate = update_gate_value[i]; + r_grad_update_gate = update_gate_grad[i]; + r_value_frame_state = frame_state_value[i]; + r_grad_frame_state = frame_state_grad[i]; + if (prev_out_value) { + r_value_prev_out = (reinterpret_cast(prev_out_value))[i]; + } + if (prev_out_grad) { + r_grad_prev_out = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; + } + r_grad_output = (reinterpret_cast<__m256 *>(output_grad))[i]; + r_value_reset_output = (reinterpret_cast<__m256 *>(reset_output_value))[i]; + if (prev_out_value && prev_out_grad) { + r_grad_reset_output = (reinterpret_cast<__m256 *>(reset_output_grad))[i]; + } + + op_gru_grad(&r_value_reset_gate, &r_grad_reset_gate, &r_value_update_gate, + &r_grad_update_gate, &r_value_frame_state, &r_grad_frame_state, + &r_value_prev_out, &r_grad_prev_out, &r_grad_output, + &r_value_reset_output, &r_grad_reset_output, active_node, + active_gate); + + reset_gate_grad[i] = r_grad_reset_gate; + update_gate_grad[i] = r_grad_update_gate; + frame_state_grad[i] = r_grad_frame_state; + if (prev_out_grad) { + (reinterpret_cast<__m256 *>(prev_out_grad))[i] = r_grad_prev_out; + } + if (prev_out_value && prev_out_grad) { + 
(reinterpret_cast<__m256 *>(reset_output_grad))[i] = r_grad_reset_output; + } + } +#endif +} + template inline void backward_state_grad(OpStateGrad op_state_grad, GRUMetaValue value, GRUMetaGrad grad, @@ -491,6 +664,39 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad, } } +template +inline void cpu_gru_backward(OpGruGrad op_gru_grad, GRUMetaValue value, + GRUMetaGrad grad, int frame_size, + int batch_size, ActivationType active_node, + ActivationType active_gate) { + for (int b = 0; b < batch_size; ++b) { + if (OpGruGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { + hl_avx_gru_backward( + op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad, + grad.output_grad, frame_size, active_node, active_gate); + } else { + hl_naive_gru_backward( + op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value, + grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad, + grad.output_grad, frame_size, active_node, active_gate); + } + + value.gate_value += frame_size * 3; + value.reset_output_value += frame_size; + if (value.prev_out_value) { + value.prev_out_value += frame_size; + } + + grad.gate_grad += frame_size * 3; + grad.output_grad += frame_size; + grad.reset_output_grad += frame_size; + if (grad.prev_out_grad) { + grad.prev_out_grad += frame_size; + } + } +} + #endif } // namespace detail diff --git a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h index 77d7ff57cd..62c45f4dc0 100644 --- a/paddle/fluid/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_gpu_kernel.h @@ -31,8 +31,8 @@ namespace detail { template __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, - T *prev_output_value, int frame_size, - int batch_size, + const T *prev_output_value, + int frame_size, int batch_size, ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; @@ -68,12 +68,10 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, * grid(frame_blocks, batch_blocks) */ template -__global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, - T *gate_value, T *prev_output_value, - T *output_value, int frame_size, - int batch_size, - ActivationType active_node, - bool origin_mode) { +__global__ void KeGruForwardFinalOutput( + OpFinalOutput op_final_output, T *gate_value, const T *prev_output_value, + T *output_value, int frame_size, int batch_size, ActivationType active_node, + bool origin_mode) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -106,8 +104,9 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, * grid(frame_blocks, 1) */ template -__global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, - T *gate_weight, T *reset_output, +__global__ void KeFastCollectiveGruGate(T *gate_value, + const T *prev_output_value, + const T *gate_weight, T *reset_output, int frame_size, ActivationType active_node) { T xt_0 = 0.0f; @@ -164,10 +163,10 @@ __global__ void KeFastCollectiveGruGate(T *gate_value, T *prev_output_value, * grid(frame_blocks, 1) */ template -__global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, - T *output_value, T *gate_value, - T *reset_value, int frame_size, - ActivationType act_node, +__global__ void 
KeFastCollectiveGruOut(const T *gate_weight, + const T *prev_out_value, T *output_value, + T *gate_value, T *reset_value, + int frame_size, ActivationType act_node, bool origin_mode) { int COL = blockIdx.x * blockDim.x + threadIdx.x; @@ -223,7 +222,7 @@ __global__ void KeFastCollectiveGruOut(T *gate_weight, T *prev_out_value, */ template __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, - T *gate_grad, T *prev_out_value, + T *gate_grad, const T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, int batch_size, ActivationType active_node, @@ -272,7 +271,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, */ template __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, - T *gate_grad, T *prev_out_value, + T *gate_grad, const T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, int batch_size, ActivationType active_gate) { diff --git a/paddle/fluid/operators/math/detail/gru_kernel.h b/paddle/fluid/operators/math/detail/gru_kernel.h index 894f5f04d2..faa4a6a06e 100644 --- a/paddle/fluid/operators/math/detail/gru_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_kernel.h @@ -30,10 +30,17 @@ class gru_resetOutput { public: HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate, T *prev_out, T *value_reset_output, - ActivationType act_gate) { + ActivationType act_gate, + T *value_reset_bias = nullptr, + bool old_version = true) { *value_update_gate = activation(*value_update_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate); - *value_reset_output = (*prev_out) * (*value_reset_gate); + if (old_version) { + *value_reset_output = (*prev_out) * (*value_reset_gate); + } else { + *value_reset_output = + (*value_reset_output + *value_reset_bias) * (*value_reset_gate); + } } #ifndef __NVCC__ #ifndef __AVX__ @@ -43,10 +50,19 @@ class gru_resetOutput { HOSTDEVICE void operator()(__m256 *value_update_gate, __m256 *value_reset_gate, __m256 *prev_out, __m256 *value_reset_output, - ActivationType act_gate) { + ActivationType act_gate, + __m256 *value_reset_bias = nullptr, + bool old_version = true) { *value_update_gate = activation(*value_update_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate); - *value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate); + if (old_version) { + *value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate); + } else { + *value_reset_output = + _mm256_add_ps(*value_reset_output, *value_reset_bias); + *value_reset_output = + _mm256_mul_ps(*value_reset_output, *value_reset_gate); + } } #endif #endif @@ -192,6 +208,61 @@ class gru_resetGrad { #endif #endif }; +template +class gru { + public: + HOSTDEVICE void operator()(T *value_reset_gate, T *grad_reset_gate, + T *value_update_gate, T *grad_update_gate, + T *value_frame_state, T *grad_frame_state, + T *value_prev_out, T *grad_prev_out, + T *grad_output, T *value_reset_output, + T *grad_reset_output, ActivationType act_node, + ActivationType act_gate) { + *grad_update_gate = + activation((*grad_output) * ((*value_prev_out) - (*value_frame_state)), + (*value_update_gate), act_gate); + *grad_prev_out += (*grad_output * (*value_update_gate)); + *grad_frame_state = + activation(*grad_output * (static_cast(1.0) - (*value_update_gate)), + *value_frame_state, act_node); + T reset_output = (*value_reset_output) / (*value_reset_gate); + *grad_reset_gate = activation(reset_output * (*grad_frame_state), + *value_reset_gate, 
act_gate); + *grad_reset_output = (*value_reset_gate) * (*grad_frame_state); + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + HOSTDEVICE void operator()(__m256 *value_reset_gate, __m256 *grad_reset_gate, + __m256 *value_update_gate, + __m256 *grad_update_gate, + __m256 *value_frame_state, + __m256 *grad_frame_state, __m256 *value_prev_out, + __m256 *grad_prev_out, __m256 *grad_output, + __m256 *value_reset_output, + __m256 *grad_reset_output, ActivationType act_node, + ActivationType act_gate) { + *grad_update_gate = activation( + _mm256_mul_ps(*grad_output, + _mm256_sub_ps(*value_prev_out, *value_frame_state)), + *value_update_gate, act_gate); + *grad_prev_out = _mm256_add_ps( + *grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate)); + *grad_frame_state = activation( + _mm256_mul_ps(*grad_output, + _mm256_sub_ps(_mm256_set1_ps(1.0f), *value_update_gate)), + *value_frame_state, act_node); + __m256 reset_output = _mm256_div_ps(*value_reset_output, *value_reset_gate); + *grad_reset_gate = + activation(_mm256_mul_ps(reset_output, *grad_frame_state), + *value_reset_gate, act_gate); + *grad_reset_output = _mm256_mul_ps(*value_reset_gate, *grad_frame_state); + } +#endif +#endif +}; } // namespace backward diff --git a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h index ad79c58063..1e7b4b35f7 100644 --- a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/lstm_compute.h" @@ -28,6 +30,11 @@ namespace operators { namespace math { namespace detail { +using Array1 = Eigen::DSizes; +template +using EigenVector = framework::EigenVector; + #ifndef __NVCC__ template @@ -35,7 +42,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frame_size, T cell_clip, ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + ActivationType active_state, + bool old_api_version) { T r_value_in; T r_value_ig; T r_value_fg; @@ -48,10 +56,15 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, T r_state_atv; T r_out; - T *value_in = value.gate_value; - T *value_ig = value.gate_value + frame_size; - T *value_fg = value.gate_value + frame_size * 2; + T *value_ig = value.gate_value; + T *value_fg = value.gate_value + frame_size; + T *value_in = value.gate_value + frame_size * 2; T *value_og = value.gate_value + frame_size * 3; + if (old_api_version) { + value_in = value.gate_value; + value_ig = value.gate_value + frame_size; + value_fg = value.gate_value + frame_size * 2; + } for (int i = 0; i < frame_size; i++) { r_value_in = value_in[i]; @@ -85,7 +98,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, T cell_clip, ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + ActivationType active_state, + bool old_api_version) { T r_value_in; T r_value_ig; T r_value_fg; @@ -107,14 +121,25 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, T r_checkFGrad; T r_checkOGrad; - T *value_in = value.gate_value; - T *value_ig = value.gate_value + frame_size; - T *value_fg = value.gate_value + 
frame_size * 2; + T *value_ig = value.gate_value; + T *value_fg = value.gate_value + frame_size; + T *value_in = value.gate_value + frame_size * 2; T *value_og = value.gate_value + frame_size * 3; - T *grad_in = grad.gate_grad; - T *grad_ig = grad.gate_grad + frame_size; - T *grad_fg = grad.gate_grad + frame_size * 2; + if (old_api_version) { + value_in = value.gate_value; + value_ig = value.gate_value + frame_size; + value_fg = value.gate_value + frame_size * 2; + } + + T *grad_ig = grad.gate_grad; + T *grad_fg = grad.gate_grad + frame_size; + T *grad_in = grad.gate_grad + frame_size * 2; T *grad_og = grad.gate_grad + frame_size * 3; + if (old_api_version) { + grad_in = grad.gate_grad; + grad_ig = grad.gate_grad + frame_size; + grad_fg = grad.gate_grad + frame_size * 2; + } for (int i = 0; i < frame_size; i++) { r_value_in = value_in[i]; @@ -158,7 +183,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frame_size, T cell_clip, ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + ActivationType active_state, + bool old_api_version) { #ifdef __AVX__ __m256 r_value_in; __m256 r_value_ig; @@ -172,12 +198,17 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, __m256 r_state_atv; __m256 r_out; - __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value); - __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); - __m256 *value_fg = + __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value); + __m256 *value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size); + __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); __m256 *value_og = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3); + if (old_api_version) { + value_in = reinterpret_cast<__m256 *>(value.gate_value); + value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); + value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); + } for (int i = 0; i < frame_size / 8; i++) { r_value_in = value_in[i]; @@ -191,7 +222,8 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, } if (value.prev_state_value) { - r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; + r_prev_state = + (reinterpret_cast<__m256 const *>(value.prev_state_value))[i]; } op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, @@ -214,7 +246,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, LstmMetaGrad grad, int frame_size, T cell_clip, ActivationType active_node, ActivationType active_gate, - ActivationType active_state) { + ActivationType active_state, + bool old_api_version) { #ifdef __AVX__ __m256 r_value_in; __m256 r_value_ig; @@ -237,16 +270,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, __m256 r_checkFGrad; __m256 r_checkOGrad; - __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value); - __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); - __m256 *value_fg = + __m256 *value_ig = reinterpret_cast<__m256 *>(value.gate_value); + __m256 *value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size); + __m256 *value_in = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); __m256 *value_og = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 3); - __m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad); - __m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size); - __m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + 
frame_size * 2); + if (old_api_version) { + value_in = reinterpret_cast<__m256 *>(value.gate_value); + value_ig = reinterpret_cast<__m256 *>(value.gate_value + frame_size); + value_fg = reinterpret_cast<__m256 *>(value.gate_value + frame_size * 2); + } + + __m256 *grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad); + __m256 *grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size); + __m256 *grad_in = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2); __m256 *grad_og = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 3); + if (old_api_version) { + grad_in = reinterpret_cast<__m256 *>(grad.gate_grad); + grad_ig = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size); + grad_fg = reinterpret_cast<__m256 *>(grad.gate_grad + frame_size * 2); + } for (int i = 0; i < frame_size / 8; i++) { r_value_in = value_in[i]; @@ -263,7 +307,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, r_output_grad = (reinterpret_cast<__m256 *>(grad.output_grad))[i]; r_state_grad = (reinterpret_cast<__m256 *>(grad.state_grad))[i]; if (value.prev_state_value) { - r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i]; + r_prev_state = + (reinterpret_cast<__m256 const *>(value.prev_state_value))[i]; } op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, @@ -292,30 +337,133 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, #endif } +template +void eigen_lstm_forward_one_sequence(const platform::CPUDeviceContext &context, + LstmMetaValue value, int frame_size) { + auto eigen_value_ig = + typename EigenVector::Type(value.gate_value, Array1(frame_size)); + auto eigen_value_fg = typename EigenVector::Type( + value.gate_value + frame_size, Array1(frame_size)); + auto eigen_value_in = typename EigenVector::Type( + value.gate_value + frame_size * 2, Array1(frame_size)); + auto eigen_value_og = typename EigenVector::Type( + value.gate_value + frame_size * 3, Array1(frame_size)); + auto eigen_state = + typename EigenVector::Type(value.state_value, Array1(frame_size)); + auto eigen_state_act = typename EigenVector::Type(value.state_active_value, + Array1(frame_size)); + auto eigen_output = + typename EigenVector::Type(value.output_value, Array1(frame_size)); + + auto &place = *context.eigen_device(); + TanhFunctor()(place, eigen_value_in, eigen_value_in); + SigmoidFunctor()(place, eigen_value_ig, eigen_value_ig); + SigmoidFunctor()(place, eigen_value_fg, eigen_value_fg); + SigmoidFunctor()(place, eigen_value_og, eigen_value_og); + + eigen_state.device(place) = eigen_value_in * eigen_value_ig; + if (value.prev_state_value) { + auto eigen_prev_state = typename EigenVector::ConstType( + value.prev_state_value, Array1(frame_size)); + eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg; + } + + TanhFunctor()(place, eigen_state, eigen_state_act); + eigen_output.device(place) = eigen_value_og * eigen_state_act; +} + +template +void eigen_lstm_backward_one_sequence(const platform::CPUDeviceContext &context, + LstmMetaValue value, + LstmMetaGrad grad, int frame_size) { + auto eigen_value_ig = + typename EigenVector::Type(value.gate_value, Array1(frame_size)); + auto eigen_value_fg = typename EigenVector::Type( + value.gate_value + frame_size, Array1(frame_size)); + auto eigen_value_in = typename EigenVector::Type( + value.gate_value + frame_size * 2, Array1(frame_size)); + auto eigen_value_og = typename EigenVector::Type( + value.gate_value + frame_size * 3, Array1(frame_size)); + auto eigen_state_act = typename 
EigenVector::Type(value.state_active_value, + Array1(frame_size)); + + auto eigen_grad_ig = + typename EigenVector::Type(grad.gate_grad, Array1(frame_size)); + auto eigen_grad_fg = typename EigenVector::Type( + grad.gate_grad + frame_size, Array1(frame_size)); + auto eigen_grad_in = typename EigenVector::Type( + grad.gate_grad + frame_size * 2, Array1(frame_size)); + auto eigen_grad_og = typename EigenVector::Type( + grad.gate_grad + frame_size * 3, Array1(frame_size)); + auto eigen_grad_output = + typename EigenVector::Type(grad.output_grad, Array1(frame_size)); + auto eigen_grad_state = + typename EigenVector::Type(grad.state_grad, Array1(frame_size)); + + auto &place = *context.eigen_device(); + SigmoidGradFunctor()(place, 1 /*useless*/, eigen_value_og, + eigen_grad_output * eigen_state_act, eigen_grad_og); + eigen_grad_state.device(place) = + eigen_grad_state + + eigen_grad_output * eigen_value_og * + (static_cast(1) - eigen_state_act * eigen_state_act); + TanhGradFunctor()(place, 1, eigen_value_in, + eigen_grad_state * eigen_value_ig, eigen_grad_in); + SigmoidGradFunctor()(place, 1, eigen_value_ig, + eigen_grad_state * eigen_value_in, eigen_grad_ig); + if (value.prev_state_value) { + auto eigen_prev_state = typename EigenVector::ConstType( + value.prev_state_value, Array1(frame_size)); + SigmoidGradFunctor()(place, 1, eigen_value_fg, + eigen_grad_state * eigen_prev_state, eigen_grad_fg); + } else { + SigmoidGradFunctor()(place, 1, eigen_value_fg, 0, eigen_grad_fg); + } + if (grad.prev_state_grad) { + auto eigen_grad_pre_state = + typename EigenVector::Type(grad.prev_state_grad, Array1(frame_size)); + eigen_grad_pre_state.device(place) = eigen_grad_state * eigen_value_fg; + } +} + template -void cpu_lstm_forward(Op op, LstmMetaValue value, int frame_size, - T cell_clip, ActivationType active_node, - ActivationType active_gate, ActivationType active_state) { - if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_forward_one_sequence(op, value, frame_size, cell_clip, - active_node, active_gate, active_state); +void cpu_lstm_forward(const platform::CPUDeviceContext &context, Op op, + LstmMetaValue value, int frame_size, T cell_clip, + ActivationType active_node, ActivationType active_gate, + ActivationType active_state, bool old_api_version) { + if (!old_api_version) { + eigen_lstm_forward_one_sequence(context, value, frame_size); } else { - naive_lstm_forward_one_sequence(op, value, frame_size, cell_clip, - active_node, active_gate, active_state); + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_forward_one_sequence(op, value, frame_size, cell_clip, + active_node, active_gate, active_state, + old_api_version); + } else { + naive_lstm_forward_one_sequence(op, value, frame_size, cell_clip, + active_node, active_gate, active_state, + old_api_version); + } } } template -void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, +void cpu_lstm_backward(const platform::CPUDeviceContext &context, Op op, + LstmMetaValue value, LstmMetaGrad grad, int frame_size, T cell_clip, ActivationType active_node, - ActivationType active_gate, - ActivationType active_state) { - if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { - avx_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, - active_node, active_gate, active_state); + ActivationType active_gate, ActivationType active_state, + bool old_api_version) { + if (!old_api_version) { + eigen_lstm_backward_one_sequence(context, value, grad, frame_size); } else 
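With old_api_version false, the CPU LSTM now takes the Eigen path above, which evaluates the textbook gate equations in the new [input, forget, cell-candidate, output] gate order (the old API keeps [candidate, input, forget, output]). A scalar sketch of one step under those assumptions, with illustrative names rather than the kernel's real API:

#include <cmath>
#include <vector>

// gate_value holds the pre-activations [i_pre, f_pre, g_pre, o_pre], each
// frame_size wide; c_prev may be empty on the first step (no previous cell).
void lstm_step(const std::vector<double>& gate_value,
               const std::vector<double>& c_prev, std::vector<double>& c_out,
               std::vector<double>& h_out, int frame_size) {
  auto sigmoid = [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
  for (int k = 0; k < frame_size; ++k) {
    double i = sigmoid(gate_value[k]);                     // input gate
    double f = sigmoid(gate_value[frame_size + k]);        // forget gate
    double g = std::tanh(gate_value[2 * frame_size + k]);  // candidate cell
    double o = sigmoid(gate_value[3 * frame_size + k]);    // output gate
    double c = i * g + (c_prev.empty() ? 0.0 : f * c_prev[k]);
    c_out[k] = c;
    h_out[k] = o * std::tanh(c);  // tanh(c) is what state_active_value caches
  }
}

The Eigen backward path differentiates exactly these expressions, which is why it only needs the cached gate activations, state_active_value, and prev_state_value.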
{ - naive_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, - active_node, active_gate, active_state); + if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same::value)) { + avx_lstm_backward_one_sequence(op, value, grad, frame_size, cell_clip, + active_node, active_gate, active_state, + old_api_version); + } else { + naive_lstm_backward_one_sequence(op, value, grad, frame_size, + cell_clip, active_node, active_gate, + active_state, old_api_version); + } } } diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index 4b8a6274cc..aa726118de 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -11,6 +11,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/gru_compute.h" +#include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" @@ -101,11 +102,64 @@ struct GRUUnitGradFunctor { } }; +template +struct GRUUnitFunctorV2 { + static void compute(const platform::CPUDeviceContext &context, + GRUMetaValue value, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { +#ifndef __NVCC__ + auto blas = math::GetBlas(context); + if (value.prev_out_value) { + blas.GEMM(CblasNoTrans, CblasTrans, batch_size, frame_size, frame_size, 1, + value.prev_out_value, value.state_weight, 0, + value.reset_output_value); + } + detail::forward_reset_output(detail::forward::gru_resetOutput(), value, + frame_size, batch_size, active_gate, false); + + T *cell_state_value = value.gate_value + 2 * frame_size; + T *reset_output_value = value.reset_output_value; + for (int b = 0; b < batch_size; ++b) { + blas.VADD(frame_size, cell_state_value, reset_output_value, + cell_state_value); + cell_state_value += frame_size * 3; + reset_output_value += frame_size; + } + + detail::forward_final_output(detail::forward::gru_finalOutput(), value, + frame_size, batch_size, active_node, true, + false); +#endif + } +}; + +template +struct GRUUnitGradFunctorV2 { + static void compute(const platform::CPUDeviceContext &context, + GRUMetaValue value, GRUMetaGrad grad, + int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { +#ifndef __NVCC__ + // calculate grad_update_gate, grad_frame_state, + // grad_reset_output, grad_reset_gate + detail::cpu_gru_backward(detail::backward::gru(), value, grad, + frame_size, batch_size, active_node, active_gate); +#endif + } +}; + template struct GRUUnitFunctor; template struct GRUUnitFunctor; template struct GRUUnitGradFunctor; template struct GRUUnitGradFunctor; +template struct GRUUnitFunctorV2; +template struct GRUUnitFunctorV2; +template struct GRUUnitGradFunctorV2; +template struct GRUUnitGradFunctorV2; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/gru_compute.h b/paddle/fluid/operators/math/gru_compute.h index f5ddec0aaa..cd713d1929 100644 --- a/paddle/fluid/operators/math/gru_compute.h +++ b/paddle/fluid/operators/math/gru_compute.h @@ -21,12 +21,13 @@ namespace math { template struct GRUMetaValue { - T *gate_weight; - T *state_weight; + const T *gate_weight; + const T *state_weight; + const T *reset_bias; T *gate_value; T *reset_output_value; T *output_value; - T *prev_out_value; + const T *prev_out_value; }; template @@ -37,6 +38,7 @@ struct GRUMetaGrad 
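GRUUnitFunctorV2 above chains the pieces of the new-format GRU: a GEMM writes h_prev * W_hn^T into reset_output_value, forward_reset_output activates the gates (reset gate first, update gate second in the new layout) and forms (reset_output + reset_bias) * r, VADD folds that into the candidate pre-activation, and forward_final_output emits the hidden state. A scalar sketch of the resulting recurrence, assuming the standard GRU formulation and illustrative names:

#include <cmath>
#include <vector>

// gate_value holds [r_pre, z_pre, n_pre_input_side], each frame_size wide;
// reset_output starts as h_prev * W_hn^T; reset_bias is the hidden-side bias
// of the candidate gate (the trailing third of bias_hh).
void gru_step(const std::vector<double>& gate_value,
              std::vector<double>& reset_output,
              const std::vector<double>& reset_bias,
              const std::vector<double>& h_prev, std::vector<double>& h_out,
              int frame_size) {
  auto sigmoid = [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
  for (int i = 0; i < frame_size; ++i) {
    double r = sigmoid(gate_value[i]);               // reset gate
    double z = sigmoid(gate_value[frame_size + i]);  // update gate
    reset_output[i] = (reset_output[i] + reset_bias[i]) * r;
    double n = std::tanh(gate_value[2 * frame_size + i] + reset_output[i]);
    h_out[i] = z * h_prev[i] + (1.0 - z) * n;        // origin_mode blend
  }
}

GRUUnitGradFunctorV2 then runs the matching backward::gru functor once per batch row, producing the gate, candidate and previous-hidden gradients in a single fused pass instead of the two-stage state/reset split used by the old kernels.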
{ T *reset_output_grad; T *output_grad; T *prev_out_grad; + T *state_bias_grad; }; template @@ -57,6 +59,22 @@ struct GRUUnitGradFunctor { bool origin_mode); }; +template +struct GRUUnitFunctorV2 { + static void compute(const DeviceContext &context, GRUMetaValue value, + int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate); +}; + +template +struct GRUUnitGradFunctorV2 { + static void compute(const DeviceContext &context, GRUMetaValue value, + GRUMetaGrad grad, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/lstm_compute.cc b/paddle/fluid/operators/math/lstm_compute.cc index 7e74f68801..aa4fe65a52 100644 --- a/paddle/fluid/operators/math/lstm_compute.cc +++ b/paddle/fluid/operators/math/lstm_compute.cc @@ -33,10 +33,12 @@ struct LstmUnitFunctor { LstmMetaValue value, int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act) { + const detail::ActivationType& cand_act, + bool old_api_version = true) { for (int b = 0; b < batch_size; b++) { - detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, - cell_clip, cand_act, gate_act, cell_act); + detail::cpu_lstm_forward(context, detail::forward::lstm(), value, + frame_size, cell_clip, cand_act, gate_act, + cell_act, old_api_version); value.gate_value += frame_size * 4; value.state_value += frame_size; value.state_active_value += frame_size; @@ -55,11 +57,12 @@ struct LstmUnitGradFunctor { int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act) { + const detail::ActivationType& cand_act, + bool old_api_version = true) { for (int b = 0; b < batch_size; b++) { - detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, - frame_size, cell_clip, cand_act, gate_act, - cell_act); + detail::cpu_lstm_backward(context, detail::backward::lstm(), value, + grad, frame_size, cell_clip, cand_act, gate_act, + cell_act, old_api_version); value.gate_value += frame_size * 4; value.state_value += frame_size; diff --git a/paddle/fluid/operators/math/lstm_compute.cu b/paddle/fluid/operators/math/lstm_compute.cu index e7445d3d40..4342cb7b79 100644 --- a/paddle/fluid/operators/math/lstm_compute.cu +++ b/paddle/fluid/operators/math/lstm_compute.cu @@ -26,7 +26,8 @@ struct LstmUnitFunctor { LstmMetaValue value, int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act) { + const detail::ActivationType& cand_act, + bool old_api_version = true) { detail::gpu_lstm_forward(context, detail::forward::lstm(), value, frame_size, batch_size, cell_clip, cand_act, gate_act, cell_act); @@ -40,7 +41,8 @@ struct LstmUnitGradFunctor { int frame_size, int batch_size, T cell_clip, const detail::ActivationType& gate_act, const detail::ActivationType& cell_act, - const detail::ActivationType& cand_act) { + const detail::ActivationType& cand_act, + bool old_api_version = true) { detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, frame_size, batch_size, cell_clip, cand_act, gate_act, cell_act); diff --git a/paddle/fluid/operators/math/lstm_compute.h 
b/paddle/fluid/operators/math/lstm_compute.h index 80af563938..cc91f784f3 100644 --- a/paddle/fluid/operators/math/lstm_compute.h +++ b/paddle/fluid/operators/math/lstm_compute.h @@ -25,7 +25,7 @@ namespace math { template struct LstmMetaValue { T *gate_value; - T *prev_state_value; + const T *prev_state_value; T *state_value; T *state_active_value; T *output_value; @@ -53,7 +53,8 @@ class LstmUnitFunctor { int frame_size, int batch_size, T cell_clip, const detail::ActivationType &gate_act, const detail::ActivationType &cell_act, - const detail::ActivationType &cand_act); + const detail::ActivationType &cand_act, + bool old_api_version = true); }; template @@ -63,7 +64,8 @@ class LstmUnitGradFunctor { LstmMetaGrad grad, int frame_size, int batch_size, T cell_clip, const detail::ActivationType &gate_act, const detail::ActivationType &cell_act, - const detail::ActivationType &cand_act); + const detail::ActivationType &cand_act, + bool old_api_version = true); }; } // namespace math diff --git a/paddle/fluid/operators/rnn_op.cc b/paddle/fluid/operators/rnn_op.cc index dfdd32e10b..2c1fcb104a 100644 --- a/paddle/fluid/operators/rnn_op.cc +++ b/paddle/fluid/operators/rnn_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/rnn_op.h" #include #include #include "paddle/fluid/framework/op_registry.h" @@ -251,5 +252,10 @@ REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker, ops::RNNGradOpMaker); REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp); -REGISTER_OP_CPU_KERNEL(rnn, ops::NotImpleKernel); -REGISTER_OP_CPU_KERNEL(rnn_grad, ops::NotImpleKernel); +REGISTER_OP_CPU_KERNEL( + rnn, ops::RNNCPUKernel, + ops::RNNCPUKernel); + +REGISTER_OP_CPU_KERNEL( + rnn_grad, ops::RNNCPUGradKernel, + ops::RNNCPUGradKernel); diff --git a/paddle/fluid/operators/rnn_op.cu.cc b/paddle/fluid/operators/rnn_op.cu.cc index f38bfd5968..5afccad177 100644 --- a/paddle/fluid/operators/rnn_op.cu.cc +++ b/paddle/fluid/operators/rnn_op.cu.cc @@ -524,6 +524,12 @@ class RNNGradCudnnKernel : public framework::OpKernel { offset += len; } + Tensor input_grad_value; + if (!in_grad) { + in_grad = &input_grad_value; + in_grad->Resize(input->dims()); + } + auto *init_h_data = pre_state[0]->data(); // auto *last_h_data = state[0]->data(); auto *last_h_grad_data = state_grad[0]->data(); diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h new file mode 100644 index 0000000000..599cb31dea --- /dev/null +++ b/paddle/fluid/operators/rnn_op.h @@ -0,0 +1,2085 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/fluid/operators/math/detail/activation_functions.h" +#include "paddle/fluid/operators/math/fc.h" +#include "paddle/fluid/operators/math/gru_compute.h" +#include "paddle/fluid/operators/math/lstm_compute.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/unique_op.h" +#include "paddle/fluid/operators/utils.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; +using TensorList = std::vector; + +#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \ + inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \ + const std::string& mode = ctx.Attr("mode"); \ + return mode == #MODE_STR; \ + } + +DEFINE_MODE_DETECTOR(lstm, LSTM); +DEFINE_MODE_DETECTOR(gru, GRU); +DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU); +DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH); + +void SwapPoniter(Tensor** a, Tensor** b) { + Tensor* c = *a; + *a = *b; + *b = c; +} + +template +void create_mask_matrix(const framework::ExecutionContext& context, + const Tensor* sequence_length, Tensor* mask_matrix, + const bool& is_reverse, int* min_seq_len) { + const auto& seq_len_vec = GetDataFromTensor(sequence_length); + const int& table_width = mask_matrix->dims()[0]; + Tensor temp; + temp.Resize( + framework::make_ddim({mask_matrix->dims()[1], mask_matrix->dims()[0]})); + T* data_temp = temp.mutable_data(context.GetPlace()); + std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast(1.0)); + *min_seq_len = table_width; + for (unsigned int i = 0; i < seq_len_vec.size(); i++) { + // reset the mask matrix + *min_seq_len = std::min(seq_len_vec[i], *min_seq_len); + if (seq_len_vec[i] == table_width) { + continue; + } + if (is_reverse) { + std::fill(data_temp + i * table_width, + data_temp + (i + 1) * table_width - seq_len_vec[i], + static_cast(0)); + } else { + std::fill(data_temp + i * table_width + seq_len_vec[i], + data_temp + (i + 1) * table_width, static_cast(0)); + } + } + mask_matrix->mutable_data(context.GetPlace()); + std::vector trans_vec; + trans_vec.emplace_back(1); + trans_vec.emplace_back(0); + auto& dev_ctx = context.template device_context(); + TransCompute(2, dev_ctx, temp, mask_matrix, + trans_vec); +} + +template +struct Cell { + virtual ~Cell() {} + virtual void operator()(const platform::CPUDeviceContext* device_ctx, + Tensor* input, const Tensor* weight_hh, + const Tensor* init_h, const Tensor* init_c, + Tensor* last_h, Tensor* last_c, Tensor* last_c_act, + Tensor* output, const Tensor* bias_hh, + Tensor* weight_hh_gru) const {} +}; + +template class EigenActivationFunctor, + math::detail::ActivationType act_type> +struct SimpleRNNCell : Cell { + void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, + const Tensor* weight_hh, const Tensor* init_h, + const Tensor* init_c, Tensor* last_h, Tensor* last_c, + Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, + Tensor* weight_hh_gru) const override { + auto blas = math::GetBlas(*device_ctx); + auto mat_dim_a = math::CreateMatrixDescriptor(init_h->dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, true); + mat_dim_a.height_ *= 
mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + // convert the batch matmul to matmul, this operator could be speed faster + blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), + input, static_cast(1.0)); + auto z = EigenVector::Flatten( + GET_DATA_SAFELY(input, "Input", "z", "Activation")); + auto hidden = EigenVector::Flatten( + GET_DATA_SAFELY(output, "Output", "hidden", "Activation")); + + auto* place = device_ctx->eigen_device(); + EigenActivationFunctor functor; + functor(*place, z, hidden); + } +}; + +template +struct GRUCell : Cell { + void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, + const Tensor* weight_hh, const Tensor* init_h, + const Tensor* init_c, Tensor* last_h, Tensor* last_c, + Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, + Tensor* weight_hh_gru) const override { + auto blas = math::GetBlas(*device_ctx); + auto mat_dim_a = math::CreateMatrixDescriptor(init_h->dims(), 0, false); + auto mat_dim_b = + math::CreateMatrixDescriptor(weight_hh_gru->dims(), 0, true); + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + // convert the batch matmul to matmul, this operator could be speed faster + blas.MatMul(*init_h, mat_dim_a, *weight_hh_gru, mat_dim_b, + static_cast(1.0), input, static_cast(1.0)); + size_t frame_size = init_h->dims()[2]; + size_t batch_size = init_h->dims()[1]; + + math::GRUMetaValue gru_value; + gru_value.gate_weight = weight_hh->data(); + gru_value.state_weight = weight_hh->data() + 2 * frame_size * frame_size; + gru_value.reset_bias = bias_hh->data() + 2 * frame_size; + + gru_value.gate_value = input->data(); + gru_value.reset_output_value = last_c->data(); + gru_value.output_value = output->data(); + gru_value.prev_out_value = init_h->data(); + + auto gate_act = math::detail::GetActivationType("sigmoid_v2"); + auto cand_act = math::detail::GetActivationType("tanh_v2"); + + math::GRUUnitFunctorV2::compute( + *device_ctx, gru_value, frame_size, batch_size, cand_act, gate_act); + } +}; + +template +struct LSTMCell : Cell { + void operator()(const platform::CPUDeviceContext* device_ctx, Tensor* input, + const Tensor* weight_hh, const Tensor* init_h, + const Tensor* init_c, Tensor* last_h, Tensor* last_c, + Tensor* last_c_act, Tensor* output, const Tensor* bias_hh, + Tensor* weight_hh_gru) const override { + auto blas = math::GetBlas(*device_ctx); + auto mat_dim_a = math::CreateMatrixDescriptor(init_h->dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, true); + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + // convert the batch matmul to matmul, this operator could be speed faster + blas.MatMul(*init_h, mat_dim_a, *weight_hh, mat_dim_b, static_cast(1.0), + input, static_cast(1.0)); + + math::LstmMetaValue lstm_value; + lstm_value.check_ig = nullptr; + lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; + + auto gate_act = math::detail::GetActivationType("sigmoid_v2"); + auto cell_act = math::detail::GetActivationType("tanh_v2"); + auto cand_act = math::detail::GetActivationType("tanh_v2"); + + size_t frame_size = init_h->dims()[2]; + size_t batch_size = init_h->dims()[1]; + + Tensor cell_pre_act; + if (last_c_act == nullptr) { /* is test */ + cell_pre_act.mutable_data(init_h->dims(), device_ctx->GetPlace()); + last_c_act = &cell_pre_act; + } + + lstm_value.prev_state_value = init_c->data(); + lstm_value.gate_value = input->data(); + lstm_value.output_value = output->data(); + 
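The cells above all begin with the trick flagged by the "convert the batch matmul to matmul" comments: a 3-D operand such as [time_step, batch, width] is viewed as a single 2-D matrix by folding the leading dimensions (mat_dim_a.height_ *= mat_dim_a.batch_size_; batch_size_ = 0), so one GEMM against the transposed weight replaces a batched matmul. A naive stand-in for blas.MatMul that shows only the shape handling (hypothetical helper, no BLAS):

#include <vector>

// x: [rows, in_size] where rows already equals time_step * batch;
// w: [out_size, in_size]; returns y = x * w^T of shape [rows, out_size].
std::vector<float> project(const std::vector<float>& x, int rows, int in_size,
                           const std::vector<float>& w, int out_size) {
  std::vector<float> y(static_cast<size_t>(rows) * out_size, 0.0f);
  for (int r = 0; r < rows; ++r)
    for (int o = 0; o < out_size; ++o)
      for (int k = 0; k < in_size; ++k)
        y[r * out_size + o] += x[r * in_size + k] * w[o * in_size + k];
  return y;
}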
lstm_value.state_value = last_c->data(); + lstm_value.state_active_value = last_c_act->data(); + T cell_clip = 0.0; + math::LstmUnitFunctor::compute( + *device_ctx, lstm_value, frame_size, batch_size, cell_clip, gate_act, + cell_act, cand_act, false); + } +}; + +template +void dropout_cpu_function_inplace(const framework::ExecutionContext& context, + Tensor* x, Tensor* mask, + const float& dropout_prob, + const int& seed_number, const bool& is_test, + bool* is_has_reset) { + if (is_test) { + return; + } + auto* x_data = x->data(); + size_t size = framework::product(x->dims()); + auto* mask_data = mask->data(); + if (!(*is_has_reset)) { + // Special case when dropout_prob is 1.0 + if (dropout_prob == 1.0f) { + std::fill(x_data, x_data + size, static_cast(0)); + std::fill(mask_data, mask_data + size, static_cast(0)); + *is_has_reset = true; + return; + } + auto engine = framework::GetCPURandomEngine(seed_number); + std::uniform_real_distribution dist(0, 1); + for (size_t i = 0; i < size; ++i) { + if (dist(*engine) < dropout_prob) { + mask_data[i] = 0; + x_data[i] = static_cast(0); + } else { + mask_data[i] = 1; + x_data[i] /= static_cast(1.0f - dropout_prob); + } + } + *is_has_reset = true; + } else { + if (dropout_prob == 1.0f) { + std::fill(x_data, x_data + size, static_cast(0)); + return; + } + for (size_t i = 0; i < size; ++i) { + if (mask_data[i] == 0) { + x_data[i] = static_cast(0); + } else { + x_data[i] /= static_cast(1.0f - dropout_prob); + } + } + } +} + +template +void dropout_cpu_grad_function_inplace( + const framework::ExecutionContext& context, Tensor* grad_x, + const Tensor* mask, const float& dropout_prob) { + auto& place = *context.template device_context() + .eigen_device(); + auto M = EigenVector::Flatten(*mask); + auto dX = EigenVector::Flatten(*grad_x); + if (dropout_prob == 1.0f) { + dX.device(place) = static_cast(0) * dX; + } else { + dX.device(place) = dX * M.cast() / static_cast(1.0f - dropout_prob); + } +} + +template +struct Layer { + explicit Layer(const CellType& cell) : cell_(cell) {} + virtual ~Layer() {} + void preprocess(const framework::ExecutionContext& context, + const Tensor* input, const Tensor& weight, + const Tensor& bias_ih, const Tensor& bias_hh, + Tensor* cache_input, bool is_test) { + // crate the temp input for the X * W_ih^T + Bias_ih + auto& dev_ctx = + context.template device_context(); + const int& hidden_size = weight.dims()[0]; + cache_input->Resize(framework::make_ddim( + {input->dims()[0], input->dims()[1], hidden_size})); + if (is_test) { + cache_input->mutable_data(context.GetPlace()); + } + auto blas = math::GetBlas(dev_ctx); + auto mat_dim_a = math::CreateMatrixDescriptor(input->dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(weight.dims(), 0, true); + // convert the batch matmul to matmul, this operator could be speed faster + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + blas.MatMul(*input, mat_dim_a, weight, mat_dim_b, static_cast(1.0), + cache_input, static_cast(0)); + + auto eigen_in = framework::EigenMatrix::Reshape( + *cache_input, cache_input->dims().size() - 1); + auto eigen_bias_ih = framework::EigenMatrix::From( + bias_ih, framework::make_ddim({1, bias_ih.dims()[0]})); + const int& row_num = + framework::product(cache_input->dims()) / cache_input->dims()[2]; + eigen_in = + eigen_in + eigen_bias_ih.broadcast(Eigen::DSizes(row_num, 1)); + if (is_gru(context)) { + // reset_gate update_gate cell_gate = [1, 1, 0] + Tensor bias_hh_tmp; + bias_hh_tmp.Resize({bias_hh.numel()}); + 
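dropout_cpu_function_inplace above is inverted dropout applied in place: the first call of a forward pass draws a 0/1 mask and rescales the kept values by 1 / (1 - p), and later calls in the same pass reuse the cached mask (the p == 1 special case is omitted here). A small sketch with std::mt19937 standing in for the framework's random engine:

#include <random>
#include <vector>

// Inverted dropout: zero a value with probability p, otherwise scale it by
// 1 / (1 - p), so the expectation is unchanged and test time needs no rescaling.
void DropoutInplace(std::vector<float>* x, std::vector<float>* mask, float p,
                    bool reuse_mask, std::mt19937* rng) {
  std::uniform_real_distribution<float> dist(0.f, 1.f);
  for (size_t i = 0; i < x->size(); ++i) {
    if (!reuse_mask) (*mask)[i] = dist(*rng) < p ? 0.f : 1.f;
    (*x)[i] = ((*mask)[i] == 0.f) ? 0.f : (*x)[i] / (1.f - p);
  }
}

int main() {
  std::mt19937 rng(2020);
  std::vector<float> h(8, 1.f), mask(8, 1.f);
  DropoutInplace(&h, &mask, 0.5f, /*reuse_mask=*/false, &rng);  // draw the mask
  DropoutInplace(&h, &mask, 0.5f, /*reuse_mask=*/true, &rng);   // reuse it
}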
bias_hh_tmp.mutable_data(context.GetPlace()); + framework::TensorCopy(bias_hh, context.GetPlace(), dev_ctx, &bias_hh_tmp); + bias_hh_tmp.Resize({3, bias_hh_tmp.numel() / 3}); + auto bias_hh_tmp_unbind = Unbind(bias_hh_tmp); + math::SetConstant zero; + zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast(0.0)); + + auto eigen_bias_hh_tmp = framework::EigenMatrix::From( + bias_hh_tmp, framework::make_ddim({1, bias_hh.dims()[0]})); + eigen_in = eigen_in + + eigen_bias_hh_tmp.broadcast(Eigen::DSizes(row_num, 1)); + } else { + auto eigen_bias_hh = framework::EigenMatrix::From( + bias_hh, framework::make_ddim({1, bias_hh.dims()[0]})); + eigen_in = + eigen_in + eigen_bias_hh.broadcast(Eigen::DSizes(row_num, 1)); + } + } + + void postprocess(const framework::ExecutionContext& context, Tensor* output, + const Tensor* init_h, const Tensor* init_c, Tensor* last_h, + Tensor* last_c, const Tensor& mask_tensor) { + // in the output, if mask flag is 0, we will retun the zero data + auto& place = *context.template device_context() + .eigen_device(); + auto eigen_output = + framework::EigenMatrix::Reshape(*output, output->dims().size() - 1); + auto eigen_mask = framework::EigenMatrix::From( + mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1})); + auto eigen_init_h = + framework::EigenMatrix::Reshape(*init_h, init_h->dims().size() - 1); + auto eigen_last_h = + framework::EigenMatrix::Reshape(*last_h, last_h->dims().size() - 1); + auto eigen_mask_broadcast = + eigen_mask.broadcast(Eigen::DSizes(1, output->dims()[2])); + eigen_last_h.device(place) = eigen_output * eigen_mask_broadcast + + eigen_init_h * (1 - eigen_mask_broadcast); + eigen_output.device(place) = eigen_output * eigen_mask_broadcast; + + if (is_lstm(context)) { + auto eigen_init_c = framework::EigenMatrix::Reshape( + *init_c, init_c->dims().size() - 1); + auto eigen_last_c = framework::EigenMatrix::Reshape( + *last_c, last_c->dims().size() - 1); + eigen_last_c.device(place) = eigen_last_c * eigen_mask_broadcast + + eigen_init_c * (1 - eigen_mask_broadcast); + } + } + + virtual void operator()(const framework::ExecutionContext& context, + const Tensor* input, const TensorList& vec, + const TensorList& init_h, const TensorList& init_c, + const Tensor* sequence_length, TensorList last_h, + TensorList last_c, Tensor* output, + const int& layer_idx, const int& gate_num, + Tensor* gate_value, Tensor* cell_value, + Tensor* cell_act_value, bool is_test) {} + + void RunTestIter(const framework::ExecutionContext& context, + const Tensor* input, const TensorList& vec, + const TensorList& init_h, const TensorList& init_c, + const Tensor* sequence_length, TensorList* last_h_ptr, + TensorList* last_c_ptr, Tensor* output, int layer_idx, + Tensor* gate_value, Tensor* cell_value, + Tensor* cell_act_value, bool is_bidirect, int offset) { + bool is_reverse = false; + if (is_bidirect) { + layer_idx = 2 * layer_idx + offset; + if (offset > 0) { + is_reverse = true; + } + } + auto& dev_ctx = + context.template device_context(); + const int& time_step = input->dims()[0]; + this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4], + vec[3 + offset * 4], gate_value, true); + auto input_tensors = Unbind(*gate_value); + auto output_tensors = Unbind(*output); + if (is_reverse) { + std::reverse(input_tensors.begin(), input_tensors.end()); + std::reverse(output_tensors.begin(), output_tensors.end()); + } + TensorList mask_tensor_list; + // construct the mask matrix for the mask + bool has_sequence_length = false; + if (sequence_length != nullptr) { + 
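postprocess above applies the usual padding rule for variable-length batches: where the mask is 0 the step is padding, so last_h keeps the incoming state and the step's output is zeroed; where it is 1 the freshly computed hidden is kept. The same blend per element, with hidden_size folded into the flat index (values are illustrative):

#include <vector>

// One time step: out / init_h / last_h are [batch * hidden], mask is [batch].
void MaskStep(std::vector<float>* out, const std::vector<float>& init_h,
              std::vector<float>* last_h, const std::vector<float>& mask,
              int hidden) {
  for (size_t i = 0; i < out->size(); ++i) {
    float m = mask[i / hidden];
    (*last_h)[i] = (*out)[i] * m + init_h[i] * (1.f - m);  // keep old state on padding
    (*out)[i] *= m;                                        // zero padded outputs
  }
}

int main() {
  std::vector<float> out{2.f, 2.f, 3.f, 3.f}, init{1.f, 1.f, 9.f, 9.f};
  std::vector<float> last(4, 0.f), mask{1.f, 0.f};
  MaskStep(&out, init, &last, mask, /*hidden=*/2);  // last = {2,2,9,9}, out = {2,2,0,0}
}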
has_sequence_length = true; + } + Tensor mask_matrix; + int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]})); + + create_mask_matrix(context, sequence_length, &mask_matrix, is_reverse, + &mask_min_length); + mask_tensor_list = Unbind(mask_matrix); + } + if (is_reverse) { + mask_min_length = mask_min_length - time_step + 1; + } + bool has_allocate_mem_c = false; + bool has_use_last_h_holder = false; + const int& reverse_flag = is_reverse ? -1 : 1; + + // define the init_h holder for the swap + Tensor init_h_temp; + framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx, + &init_h_temp); + Tensor* init_h_holder = &init_h_temp; + Tensor* last_h_holder = nullptr; + if (0 < mask_min_length) { + last_h_holder = &(output_tensors[0]); + } else { + last_h_holder = &(*last_h_ptr)[layer_idx]; + has_use_last_h_holder = true; + } + + Tensor* init_c_holder = nullptr; + const Tensor* init_c_temp_holder = nullptr; + Tensor init_c_temp; + Tensor* last_c_holder = nullptr; + Tensor last_c_temp; + + if (is_lstm(context)) { + last_c_holder = &(*last_c_ptr)[layer_idx]; + init_c_temp_holder = &init_c[layer_idx]; + } else if (is_gru(context)) { + // for reset output value + last_c_temp.Resize(init_h[layer_idx].dims()); + last_c_temp.mutable_data(context.GetPlace()); + last_c_holder = &last_c_temp; + } + Tensor weight_hh_tmp; // for gru + if (is_gru(context)) { + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + weight_hh_tmp.mutable_data(context.GetPlace()); + framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx, + &weight_hh_tmp); + weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3}); + auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp); + math::SetConstant zero; + zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast(0.0)); + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + } + for (int i = 0; i < time_step; i++) { + bool in_mask = (reverse_flag * i) >= mask_min_length; + if (i > 0) { + if (!has_allocate_mem_c) { + if (is_lstm(context) || is_gru(context)) { + init_c_temp.Resize(init_h[layer_idx].dims()); + init_c_temp.mutable_data(context.GetPlace()); + init_c_holder = &init_c_temp; + } + has_allocate_mem_c = true; + } + SwapPoniter(&init_c_holder, &last_c_holder); + init_c_temp_holder = init_c_holder; + } + cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder, + init_c_temp_holder, last_h_holder, last_c_holder, nullptr, + &output_tensors[i], &vec[3 + offset * 4] /* bias_hh */, + &weight_hh_tmp); + if (in_mask) { + this->postprocess(context, &output_tensors[i], init_h_holder, + init_c_temp_holder, last_h_holder, last_c_holder, + mask_tensor_list[i]); + } + // prepare next step + if (i + 1 < time_step) { + bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length; + if (next_step_mask) { + if (!has_use_last_h_holder) { + init_h_holder = &(*last_h_ptr)[layer_idx]; + } + } else { + init_h_holder = &(output_tensors[i + 1]); + } + SwapPoniter(&init_h_holder, &last_h_holder); + } + } + if (has_sequence_length) { + if (last_h_holder != &(*last_h_ptr)[layer_idx]) { + framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx, + &(*last_h_ptr)[layer_idx]); + } + } else { + framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(), + dev_ctx, &(*last_h_ptr)[layer_idx]); + } + + if (time_step % 2 == 0) { + if (is_lstm(context)) { + framework::TensorCopy(*last_c_holder, context.GetPlace(), dev_ctx, + &(*last_c_ptr)[layer_idx]); + } + } + } + + void 
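RunTestIter avoids copying hidden states between steps by ping-ponging two tensor pointers with SwapPoniter; whether the final state then already sits in the caller's buffer depends on how many swaps happened, which is what the time_step % 2 check compensates for. The idea in miniature, with a trivial update standing in for the cell:

#include <vector>

void Swap(std::vector<float>** a, std::vector<float>** b) {
  std::vector<float>* t = *a; *a = *b; *b = t;
}

int main() {
  std::vector<float> buf_a(4, 0.f), buf_b(4, 0.f);
  std::vector<float>* prev = &buf_a;   // holds h_{t-1}
  std::vector<float>* curr = &buf_b;   // receives h_t
  int time_step = 5;
  for (int t = 0; t < time_step; ++t) {
    for (size_t i = 0; i < curr->size(); ++i) (*curr)[i] = (*prev)[i] + 1.f;  // "cell"
    Swap(&prev, &curr);                // the new state becomes the next step's input
  }
  // After an odd number of swaps the final state lives in the other buffer,
  // so a copy back is only needed for one parity of time_step.
  const std::vector<float>& last_h = *prev;
  (void)last_h;
}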
RunIter(const framework::ExecutionContext& context, const Tensor* input, + const TensorList& vec, const TensorList& init_h, + const TensorList& init_c, const Tensor* sequence_length, + TensorList* last_h_ptr, TensorList* last_c_ptr, Tensor* output, + int layer_idx, Tensor* gate_value, Tensor* cell_value, + Tensor* cell_act_value, bool is_bidirect, int offset, + bool is_test) { + if (is_test) { + RunTestIter(context, input, vec, init_h, init_c, sequence_length, + last_h_ptr, last_c_ptr, output, layer_idx, gate_value, + cell_value, cell_act_value, is_bidirect, offset); + return; + } + bool is_reverse = false; + if (is_bidirect) { + layer_idx = 2 * layer_idx + offset; + if (offset > 0) { + is_reverse = true; + } + } + auto& dev_ctx = + context.template device_context(); + const int& time_step = input->dims()[0]; + this->preprocess(context, input, vec[0 + offset * 4], vec[2 + offset * 4], + vec[3 + offset * 4], gate_value, is_test); + auto input_tensors = Unbind(*gate_value); + auto output_tensors = Unbind(*output); + if (is_reverse) { + std::reverse(input_tensors.begin(), input_tensors.end()); + std::reverse(output_tensors.begin(), output_tensors.end()); + } + TensorList mask_tensor_list; + // construct the mask matrix for the mask + bool has_sequence_length = false; + if (sequence_length != nullptr) { + has_sequence_length = true; + } + Tensor mask_matrix; + int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]})); + create_mask_matrix(context, sequence_length, &mask_matrix, is_reverse, + &mask_min_length); + mask_tensor_list = Unbind(mask_matrix); + } + if (is_reverse) { + mask_min_length = mask_min_length - time_step + 1; + } + + // define the init_h holder for the swap + bool has_use_last_h_holder = false; + const int& reverse_flag = is_reverse ? 
-1 : 1; + + TensorList cell_value_tensors; + TensorList cell_act_value_tensors; + + Tensor init_h_temp; + framework::TensorCopy(*&init_h[layer_idx], context.GetPlace(), dev_ctx, + &init_h_temp); + Tensor* init_h_holder = &init_h_temp; + Tensor* last_h_holder = nullptr; + if (0 < mask_min_length) { + last_h_holder = &(output_tensors[0]); + } else { + last_h_holder = &(*last_h_ptr)[layer_idx]; + has_use_last_h_holder = true; + } + + const Tensor* init_c_holder = nullptr; + Tensor* last_c_holder = nullptr; + Tensor* last_c_act_holder = nullptr; + if (is_lstm(context) || is_gru(context)) { + cell_value->Resize({time_step, cell_value->numel() / time_step}); + cell_value_tensors = Unbind(*cell_value); + if (is_lstm(context)) { + cell_act_value->Resize( + {time_step, cell_act_value->numel() / time_step}); + cell_act_value_tensors = Unbind(*cell_act_value); + } + } + Tensor weight_hh_tmp; // for gru + if (is_gru(context)) { + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + weight_hh_tmp.mutable_data(context.GetPlace()); + framework::TensorCopy(vec[1 + offset * 4], context.GetPlace(), dev_ctx, + &weight_hh_tmp); + weight_hh_tmp.Resize({3, weight_hh_tmp.numel() / 3}); + auto weight_hh_tmp_unbind = Unbind(weight_hh_tmp); + math::SetConstant zero; + zero(dev_ctx, &weight_hh_tmp_unbind[2], static_cast(0.0)); + weight_hh_tmp.Resize(vec[1 + offset * 4].dims()); + } + for (int i = 0; i < time_step; i++) { + bool in_mask = (reverse_flag * i) >= mask_min_length; + if (is_lstm(context)) { + if (i == 0) { + init_c_holder = &init_c[layer_idx]; + } else { + init_c_holder = &cell_value_tensors[i - 1]; + } + cell_value_tensors[i].Resize(init_c[layer_idx].dims()); + cell_act_value_tensors[i].Resize(init_c[layer_idx].dims()); + last_c_holder = &cell_value_tensors[i]; + last_c_act_holder = &cell_act_value_tensors[i]; + } else if (is_gru(context)) { + cell_value_tensors[i].Resize(init_h[layer_idx].dims()); + last_c_holder = &cell_value_tensors[i]; + } + + cell_(&dev_ctx, &input_tensors[i], &vec[1 + offset * 4], init_h_holder, + init_c_holder, last_h_holder, last_c_holder, last_c_act_holder, + &output_tensors[i], &vec[3 + offset * 4] /* bias_hh */, + &weight_hh_tmp); + if (in_mask) { + this->postprocess(context, &output_tensors[i], init_h_holder, + init_c_holder, last_h_holder, last_c_holder, + mask_tensor_list[i]); + } + // prepare next step + if (i + 1 < time_step) { + bool next_step_mask = (reverse_flag * (i + 1)) >= mask_min_length; + if (next_step_mask) { + if (!has_use_last_h_holder) { + init_h_holder = &(*last_h_ptr)[layer_idx]; + } + } else { + init_h_holder = &(output_tensors[i + 1]); + } + SwapPoniter(&init_h_holder, &last_h_holder); + } + } + if (has_sequence_length) { + if (last_h_holder != &(*last_h_ptr)[layer_idx]) { + framework::TensorCopy(*last_h_holder, context.GetPlace(), dev_ctx, + &(*last_h_ptr)[layer_idx]); + } + } else { + framework::TensorCopy(output_tensors[time_step - 1], context.GetPlace(), + dev_ctx, &(*last_h_ptr)[layer_idx]); + } + if (is_lstm(context)) { + framework::TensorCopy(cell_value_tensors[time_step - 1], + context.GetPlace(), dev_ctx, + &(*last_c_ptr)[layer_idx]); + } + } + // Cell for the rnn module + CellType cell_; +}; + +template +struct SingleLayer : public Layer { + explicit SingleLayer(const CellType& cell) : Layer(cell) {} + void operator()(const framework::ExecutionContext& context, + const Tensor* input, const TensorList& vec, + const TensorList& init_h, const TensorList& init_c, + const Tensor* sequence_length, TensorList last_h, + TensorList last_c, Tensor* output, 
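Unlike the inference path, RunIter keeps every step's cell state (and, for LSTM, its pre-activation) in the reserve buffers: slot i is written as step i's new state and read back as step i + 1's previous state, so the backward pass can replay the sequence without recomputation. A sketch of that chaining with plain arrays and a stand-in cell update:

#include <vector>

int main() {
  const int time_step = 4, state_size = 3;
  std::vector<float> init_c(state_size, 0.f);
  // One reserved row per time step, written in forward and reused in backward.
  std::vector<std::vector<float>> cell_value(time_step,
                                             std::vector<float>(state_size));
  for (int i = 0; i < time_step; ++i) {
    const std::vector<float>& prev_c = (i == 0) ? init_c : cell_value[i - 1];
    for (int j = 0; j < state_size; ++j)
      cell_value[i][j] = prev_c[j] + 1.f;   // stand-in for the real cell update
  }
  // Backward can now read cell_value[i - 1] as the "previous state" of step i.
}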
const int& layer_idx, + const int& gate_num, Tensor* gate_value, Tensor* cell_value, + Tensor* cell_act_value, bool is_test) { + this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h, + &last_c, output, layer_idx, gate_value, cell_value, + cell_act_value, false, 0, is_test); + } +}; + +template +struct BidirLayer : public Layer { + explicit BidirLayer(const CellType& cell) : Layer(cell) {} + void operator()(const framework::ExecutionContext& context, + const Tensor* input, const TensorList& vec, + const TensorList& init_h, const TensorList& init_c, + const Tensor* sequence_length, TensorList last_h, + TensorList last_c, Tensor* output, const int& layer_idx, + const int& gate_num, Tensor* gate_value, Tensor* cell_value, + Tensor* cell_act_value, bool is_test) { + TensorList output_vec(2); + Tensor forward_input_w, forward_cell_value, forward_cell_act_value; + Tensor backward_input_w, backward_cell_value, backward_cell_act_value; + int time_step = input->dims()[0]; + int batch_size = input->dims()[1]; + int hidden_size = output->dims()[2]; + for (int i = 0; i < 2; ++i) { + output_vec[i].Resize({time_step, batch_size, hidden_size / 2}); + output_vec[i].mutable_data(context.GetPlace()); + } + if (!is_test) { + gate_value->Resize({2, gate_value->numel() / 2}); + forward_input_w = gate_value->Slice(0, 1); + backward_input_w = gate_value->Slice(1, 2); + + if (is_lstm(context) || is_gru(context)) /* for lstm and gru */ { + cell_value->Resize({2, cell_value->numel() / 2}); + cell_act_value->Resize({2, cell_act_value->numel() / 2}); + forward_cell_value = cell_value->Slice(0, 1); + backward_cell_value = cell_value->Slice(1, 2); + if (is_lstm(context)) { + forward_cell_act_value = cell_act_value->Slice(0, 1); + backward_cell_act_value = cell_act_value->Slice(1, 2); + } + } + } + + this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h, + &last_c, &output_vec[0], layer_idx, &forward_input_w, + &forward_cell_value, &forward_cell_act_value, true, 0, + is_test); + + this->RunIter(context, input, vec, init_h, init_c, sequence_length, &last_h, + &last_c, &output_vec[1], layer_idx, &backward_input_w, + &backward_cell_value, &backward_cell_act_value, true, 1, + is_test); + + // concat the the output result + auto& dev_ctx = + context.template device_context(); + paddle::operators::math::ConcatFunctor + concat_functor; + concat_functor(dev_ctx, output_vec, static_cast(2), output); + } +}; + +template +void SplitReserveData(const framework::ExecutionContext& ctx, + TensorType* reserve_data, Tensor* gate_data, + Tensor* cell_data, Tensor* cell_act_data, + Tensor* hidden_data, int direction_num, + const int& time_step, const int& batch_size, + const int& hidden_size, const int& gate_num, + const int& num_layers) { + const int& gate_data_idx = gate_num * num_layers; + const int& cell_data_idx = (gate_num + 1) * num_layers; + const int& cell_act_data_idx = (gate_num + 2) * num_layers; + // simple rnn + int hidden_data_start_idx = gate_data_idx; + *gate_data = reserve_data->Slice(0, gate_data_idx); + if (is_lstm(ctx)) { + *cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx); + *cell_act_data = reserve_data->Slice(cell_data_idx, cell_act_data_idx); + hidden_data_start_idx = cell_act_data_idx; + } else if (is_gru(ctx)) { + *cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx); + hidden_data_start_idx = cell_data_idx; + } + int hidden_data_idx = hidden_data_start_idx + (num_layers - 1); + if (num_layers > 1) { + *hidden_data = 
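BidirLayer runs the same cell twice, forward over the sequence and backward over the reversed sequence, each direction writing a [time_step, batch, hidden/2] buffer, and then concatenates the two along the last axis. A compact sketch of that concatenation, with the leading dimensions flattened to rows for brevity:

#include <vector>

// fw, bw: [T * B, H/2] each (already computed per direction).
// out:    [T * B, H], forward features first, then backward features.
std::vector<float> ConcatLastDim(const std::vector<float>& fw,
                                 const std::vector<float>& bw, int rows,
                                 int half) {
  std::vector<float> out(rows * 2 * half);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < half; ++c) {
      out[r * 2 * half + c] = fw[r * half + c];
      out[r * 2 * half + half + c] = bw[r * half + c];
    }
  return out;
}

int main() {
  std::vector<float> fw{1, 1, 2, 2}, bw{3, 3, 4, 4};   // T*B = 2, H/2 = 2
  auto out = ConcatLastDim(fw, bw, 2, 2);              // rows: {1,1,3,3}, {2,2,4,4}
}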
reserve_data->Slice(hidden_data_start_idx, hidden_data_idx); + } +} + +template +void reset_parameter_vector(const std::vector& raw_params_vec, + const int& num_layers, const int& gate_num, + const bool& is_bidirec, + std::vector* params_vec) { + // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers + // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to + // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers + const int& direction_num = is_bidirec ? 2 : 1; + const int& layer_weight_size = 4 * direction_num; + const int& all_weight_size = num_layers * layer_weight_size; + const int& bias_start_idx = all_weight_size / 2; + for (int i = 0; i < num_layers; i++) { + TensorList tensor_list; + tensor_list.reserve(layer_weight_size); + for (int j = 0; j < layer_weight_size; j++) { + Tensor tensor_holder; + tensor_list.emplace_back(tensor_holder); + } + for (int j = 0; j < layer_weight_size; j++) { + int k = j % 4; + const int& section = j / 4; + int tensor_idx = i * 2 * direction_num + section * 2 + k % 2; + if (k >= 2) { + tensor_idx += bias_start_idx; + } + tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]); + } + params_vec->emplace_back(tensor_list); + } +} + +template +void AllocateReserveData(const framework::ExecutionContext& ctx, + Tensor* reserve_data, Tensor* gate_data, + Tensor* cell_data, Tensor* cell_act_data, + Tensor* hidden_data, const Tensor* input, + bool is_bidirec, int num_layers, int gate_num, + int hidden_size) { + const int& direction_num = is_bidirec ? 2 : 1; + const int& time_step = input->dims()[0]; + const int& batch_size = input->dims()[1]; + const int& block_size = direction_num * time_step * batch_size * hidden_size; + int hidden_data_idx = (num_layers - 1); + if (is_lstm(ctx)) { + hidden_data_idx += (gate_num + 2) * num_layers; + } else if (is_gru(ctx)) { + hidden_data_idx += (gate_num + 1) * num_layers; + } else { + hidden_data_idx += gate_num * num_layers; + } + + reserve_data->Resize({hidden_data_idx, block_size}); + reserve_data->mutable_data(ctx.GetPlace()); + SplitReserveData(ctx, reserve_data, gate_data, cell_data, cell_act_data, + hidden_data, direction_num, time_step, batch_size, + hidden_size, gate_num, num_layers); +} + +template class LayerT, + template class SingleLayerT, + template class BidirLayerT, typename T> +void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input, + const std::vector weight_list, const Tensor* init_h, + const Tensor* init_c, const Tensor* sequence_length, + Tensor* last_h, Tensor* last_c, Tensor* output, + Tensor* dropout_mask, const int& num_layers, const int& gate_num, + const int& input_size, const int& hidden_size, + const bool& is_bidirec, const std::string& cell_type, + const float& dropout_prob, const bool& is_test, const int& seed, + Tensor* reserve_data) { + const int& direction_num = is_bidirec ? 
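reset_parameter_vector regroups the flat WeightList: as its comment says, the raw order is all weights first, [FWhi, FWhh, BWhi, BWhh] per layer, followed by all biases in the same pattern, while the kernel wants per-layer, per-direction groups of [Whi, Whh, Bhi, Bhh]. The snippet below reproduces just the index arithmetic from that function and prints which raw slot each reordered entry comes from:

#include <cstdio>

int main() {
  const int num_layers = 2;
  const bool is_bidirec = true;
  const int direction_num = is_bidirec ? 2 : 1;
  const int layer_weight_size = 4 * direction_num;     // entries per layer
  const int all_weight_size = num_layers * layer_weight_size;
  const int bias_start_idx = all_weight_size / 2;      // biases follow all weights

  for (int i = 0; i < num_layers; ++i) {
    for (int j = 0; j < layer_weight_size; ++j) {
      int k = j % 4;          // 0: W_ih, 1: W_hh, 2: B_ih, 3: B_hh
      int section = j / 4;    // 0: forward direction, 1: backward direction
      int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
      if (k >= 2) tensor_idx += bias_start_idx;
      std::printf("layer %d slot %d <- raw index %d\n", i, j, tensor_idx);
    }
  }
}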
2 : 1; + const auto& init_h_dims = init_h->dims(); + PADDLE_ENFORCE_EQ(init_h_dims[0], num_layers * direction_num, + platform::errors::InvalidArgument( + "The num_layers of in RNN layer must be the same as " + "first dim of init hidden, but received" + " num_layers:%d, dim:%d", + num_layers, init_h_dims[0])); + if (is_lstm(ctx)) { + const auto& init_c_dims = init_c->dims(); + PADDLE_ENFORCE_EQ(init_c_dims[0], num_layers * direction_num, + platform::errors::InvalidArgument( + "The num_layers of in RNN layer must be the same as " + "first dim of cell state hidden, but received" + " num_layers:%d, dim:%d", + num_layers, init_h_dims[0])); + } + CellType cell; + + std::vector parameter_lists; + parameter_lists.reserve(num_layers); + reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec, + ¶meter_lists); + + Tensor gate_data, cell_data, cell_act_data, hidden_data; + + if (!is_test) { + AllocateReserveData( + ctx, reserve_data, &gate_data, &cell_data, &cell_act_data, &hidden_data, + input, is_bidirec, num_layers, gate_num, hidden_size); + gate_data.Resize({num_layers, gate_data.numel() / num_layers}); + cell_data.Resize({num_layers, cell_data.numel() / num_layers}); + cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers}); + + if (num_layers > 1) { + hidden_data.Resize( + {num_layers - 1, hidden_data.numel() / (num_layers - 1)}); + } + } + Tensor* input_holder; + Tensor* output_holder = output; + Tensor temp; + bool has_allocate_mem = false; + + auto init_h_unbind = Unbind(*init_h); + auto last_h_unbind = Unbind(*last_h); + TensorList init_c_unbind, last_c_unbind; + if (is_lstm(ctx)) { + init_c_unbind = Unbind(*init_c); + last_c_unbind = Unbind(*last_c); + } + + Tensor curr_gate_data, curr_cell_data, curr_cell_act_data; + Tensor curr_hidden_data, prev_hidden_data; + bool has_dropout_reset = false; + for (int i = 0; i < num_layers; i++) { + if (!is_test) { + if (cell_data.numel() > 0) /** for lstm, gru **/ { + curr_cell_data = cell_data.Slice(i, i + 1); + } + if (cell_act_data.numel() > 0) /*for lstm*/ { + curr_cell_act_data = cell_act_data.Slice(i, i + 1); + } + curr_gate_data = gate_data.Slice(i, i + 1); + output_holder = output; + if (i < num_layers - 1 && num_layers > 1) { + curr_hidden_data = hidden_data.Slice(i, i + 1); + curr_hidden_data.Resize(output->dims()); + output_holder = &curr_hidden_data; + } + } + if (i > 0) { + if (!has_allocate_mem) { + temp.Resize(output->dims()); + temp.mutable_data(ctx.GetPlace()); + input_holder = &temp; + has_allocate_mem = true; + } + if (!is_test) { + prev_hidden_data = hidden_data.Slice(i - 1, i); + input_holder = &prev_hidden_data; + input_holder->Resize(output->dims()); + } else { + SwapPoniter(&output_holder, &input_holder); + } + if (dropout_prob != 0 && (!is_test)) { + dropout_cpu_function_inplace(ctx, input_holder, dropout_mask, + dropout_prob, seed, is_test, + &has_dropout_reset); + } + } + const Tensor* input_temp_holder = input; + if (i > 0) { + input_temp_holder = input_holder; + } + LayerT* layer; + SingleLayerT slayer(cell); + BidirLayerT blayer(cell); + if (is_bidirec) { + layer = &blayer; + } else { + layer = &slayer; + } + (*layer)(ctx, input_temp_holder, parameter_lists[i], init_h_unbind, + init_c_unbind, sequence_length, last_h_unbind, last_c_unbind, + output_holder, i, gate_num, &curr_gate_data, &curr_cell_data, + &curr_cell_act_data, is_test); + } + if (num_layers % 2 == 0) { + framework::TensorCopy( + *output_holder, ctx.GetPlace(), + ctx.template device_context(), output); + } +} + +template +class 
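The Reserve tensor allocated above is a 2-D buffer whose row count depends on the mode: gate activations for every layer, one cell-state row per layer for LSTM and GRU, one pre-activation row per layer for LSTM only, plus the outputs of all but the last layer (gate_num is 4, 3 and 1 for LSTM, GRU and the simple cells, as set in the kernel below). A few lines that just reproduce that row arithmetic:

#include <cstdio>

int main() {
  // Each row holds direction_num * time_step * batch_size * hidden_size values.
  const int num_layers = 2;
  struct Mode { const char* name; int gate_num; int extra_rows_per_layer; };
  const Mode modes[] = {{"LSTM", 4, 2}, {"GRU", 3, 1}, {"RNN_TANH/RELU", 1, 0}};
  for (const Mode& m : modes) {
    int rows = (m.gate_num + m.extra_rows_per_layer) * num_layers + (num_layers - 1);
    std::printf("%s: %d reserve rows\n", m.name, rows);
  }
}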
RNNCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto pre_state = ctx.MultiInput("PreState"); + auto weight_list = ctx.MultiInput("WeightList"); + auto state = ctx.MultiOutput("State"); + auto* output = ctx.Output("Out"); + auto* dropout_mask = ctx.Output("DropoutState"); + auto* reserve_data = ctx.Output("Reserve"); + const int& num_layers = ctx.Attr("num_layers"); + const bool& is_bidirec = ctx.Attr("is_bidirec"); + const int& input_size = ctx.Attr("input_size"); + const int& hidden_size = ctx.Attr("hidden_size"); + const float& dropout_prob = ctx.Attr("dropout_prob"); + const std::string& mode = ctx.Attr("mode"); + const bool& is_test = ctx.Attr("is_test"); + const int& seed = ctx.Attr("seed"); + + bool has_seq_length = ctx.HasInput("SequenceLength"); + const Tensor* sequence_length = nullptr; + if (has_seq_length) { + sequence_length = ctx.Input("SequenceLength"); + } + if (!dropout_mask->IsInitialized()) { + dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); + } + + // init the output and allocate the memory + output->mutable_data(ctx.GetPlace()); + int gate_num = 4; + state[0]->mutable_data(ctx.GetPlace()); + if (is_lstm(ctx)) { + state[1]->mutable_data(ctx.GetPlace()); + RnnFunc, Layer, SingleLayer, BidirLayer, T>( + ctx, input, weight_list, pre_state[0], pre_state[1], sequence_length, + state[0], state[1], output, dropout_mask, num_layers, gate_num, + input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, + seed, reserve_data); + } else if (is_rnn_relu(ctx)) { + gate_num = 1; + RnnFunc< + SimpleRNNCell, + Layer, SingleLayer, BidirLayer, T>( + ctx, input, weight_list, pre_state[0], nullptr, sequence_length, + state[0], nullptr, output, dropout_mask, num_layers, gate_num, + input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, + seed, reserve_data); + } else if (is_rnn_tanh(ctx)) { + gate_num = 1; + RnnFunc< + SimpleRNNCell, + Layer, SingleLayer, BidirLayer, T>( + ctx, input, weight_list, pre_state[0], nullptr, sequence_length, + state[0], nullptr, output, dropout_mask, num_layers, gate_num, + input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, + seed, reserve_data); + } else if (is_gru(ctx)) { + gate_num = 3; + RnnFunc, Layer, SingleLayer, BidirLayer, T>( + ctx, input, weight_list, pre_state[0], nullptr, sequence_length, + state[0], nullptr, output, dropout_mask, num_layers, gate_num, + input_size, hidden_size, is_bidirec, mode, dropout_prob, is_test, + seed, reserve_data); + } + } +}; + +template +void create_lstm_value(math::LstmMetaValue* lstm_value) { + lstm_value->check_ig = nullptr; + lstm_value->check_fg = nullptr; + lstm_value->check_og = nullptr; +} + +template +void create_lstm_grad(math::LstmMetaGrad* lstm_grad) { + lstm_grad->check_ig_grad = nullptr; + lstm_grad->check_fg_grad = nullptr; + lstm_grad->check_og_grad = nullptr; +} + +template +void create_tensor_by_list(const framework::ExecutionContext& context, + Tensor* dst, const std::vector& v) { + int tensor_size = v.size(); + dst->Resize({tensor_size}); + dst->mutable_data(context.GetPlace()); + int size = v.size(); + for (int i = 0; i < size; ++i) { + dst->data()[i] = v[i]; + } +} + +template +void make_grad_gate_buf(const framework::ExecutionContext& context, + Tensor* grad_gate, Tensor* grad_gate_buf, + Tensor* reset_output_grad = nullptr) { + int dim_size = grad_gate->dims().size(); + int batch_size = grad_gate->dims()[dim_size - 2]; + int frame_size = 
grad_gate->dims()[dim_size - 1]; + + Tensor grad_gate_mask; + create_tensor_by_list(context, &grad_gate_mask, {1, 1, 0}); + + auto& place = *context.template device_context() + .eigen_device(); + auto eigen_grad_gate_mask = framework::EigenMatrix::From( + grad_gate_mask, framework::make_ddim({3, 1})); + auto eigen_grad_gate_mask_broadcast = + eigen_grad_gate_mask.broadcast(Eigen::DSizes(1, frame_size / 3)) + .reshape(Eigen::DSizes(frame_size)) + .broadcast(Eigen::DSizes(batch_size, 1)); + auto eigen_grad_gate_buf = framework::EigenMatrix::From( + *grad_gate_buf, framework::make_ddim({batch_size, frame_size})); + auto eigen_grad_gate = framework::EigenMatrix::From( + *grad_gate, framework::make_ddim({batch_size, frame_size})); + eigen_grad_gate_buf.device(place) = + eigen_grad_gate * eigen_grad_gate_mask_broadcast; + + if (reset_output_grad) { + Tensor grad_reset_output_mask; + create_tensor_by_list(context, &grad_reset_output_mask, {0, 0, 1}); + auto eigen_grad_reset_output_mask = framework::EigenMatrix::From( + grad_reset_output_mask, framework::make_ddim({3, 1})); + auto eigen_grad_reset_output_mask_broadcast = + eigen_grad_reset_output_mask + .broadcast(Eigen::DSizes(1, frame_size / 3)) + .reshape(Eigen::DSizes(frame_size)) + .broadcast(Eigen::DSizes(batch_size, 1)); + auto eigen_grad_reset_output = + framework::EigenMatrix::Reshape(*reset_output_grad, + reset_output_grad->dims().size() - 1) + .broadcast(Eigen::DSizes(1, 3, 1)) + .reshape(Eigen::DSizes(batch_size, frame_size)); + eigen_grad_gate_buf.device(place) = + eigen_grad_gate_buf + + eigen_grad_reset_output_mask_broadcast * eigen_grad_reset_output; + } +} + +template +struct GradLayer { + explicit GradLayer(const GradCellType& cell) : cell_(cell) {} + virtual ~GradLayer() {} + void run_rnn_grad_function( + const framework::ExecutionContext& context, + const platform::CPUDeviceContext& device_ctx, const Tensor* input, + Tensor* input_grad, const Tensor* sequence_length, + std::vector* init_h_unbind, std::vector* init_c_unbind, + std::vector* init_h_grad_unbind, + std::vector* init_c_grad_unbind, Tensor* layer_grad_gate_tensor, + std::vector* layer_gate_tensor_unbind, + std::vector* layer_grad_gate_tensor_unbind, + std::vector* layer_state_tensor_unbind, + std::vector* layer_act_state_tensor_unbind, + std::vector* output_tensor_unbind, + std::vector* output_grad_tensor_unbind, + const TensorList& last_h_grad_unbind, + const TensorList& last_c_grad_unbind, + const std::vector& parameter_lists, + std::vector* weight_list_grad, const int& layer_idx, + const int& time_step, const bool& has_sequence_length, + const bool& is_bidirec, const bool& is_reverse) { + const int& direction_num = is_bidirec ? 2 : 1; + const int& current_reverse_idx = is_reverse ? 
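make_grad_gate_buf assembles the buffer that the GRU backward uses in place of the raw gate gradients when accumulating the hidden-to-hidden and bias_hh gradients: the reset and update thirds are kept as they are, while the candidate third is zeroed and replaced by the reset-output gradient broadcast into that slot. The same masking written with plain loops; the names and toy values are illustrative:

#include <vector>

// grad_gate:      [batch, 3 * frame]  per-gate gradients (reset, update, candidate)
// grad_reset_out: [batch, frame]      gradient of the "reset output" term
// buf:            [batch, 3 * frame]  what the hidden-to-hidden grads are built from
void MakeGradGateBuf(const std::vector<float>& grad_gate,
                     const std::vector<float>& grad_reset_out,
                     std::vector<float>* buf, int batch, int frame) {
  for (int b = 0; b < batch; ++b)
    for (int g = 0; g < 3; ++g)
      for (int f = 0; f < frame; ++f) {
        int idx = (b * 3 + g) * frame + f;
        // gates 0,1 keep their own gradient; gate 2 takes the reset-output grad
        (*buf)[idx] = (g < 2) ? grad_gate[idx] : grad_reset_out[b * frame + f];
      }
}

int main() {
  int batch = 1, frame = 2;
  std::vector<float> grad_gate{1, 1, 2, 2, 3, 3}, grad_reset_out{9, 9};
  std::vector<float> buf(6, 0.f);
  MakeGradGateBuf(grad_gate, grad_reset_out, &buf, batch, frame);  // {1,1,2,2,9,9}
}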
1 : 0; + const int& current_layer_idx = + direction_num * layer_idx + current_reverse_idx; + int begin_idx = 0; + if (is_reverse) { + begin_idx = time_step; + } + + Tensor mask_matrix; + TensorList mask_tensor_list; + int mask_min_length = time_step; + if (has_sequence_length) { + mask_matrix.Resize(framework::make_ddim({time_step, input->dims()[1]})); + create_mask_matrix(context, sequence_length, &mask_matrix, is_reverse, + &mask_min_length); + mask_tensor_list = Unbind(mask_matrix); + } + // copy the last_h, last_c for swaping pointer + Tensor a, b; + Tensor* dynamic_grad_last_h = &a; + Tensor* dynamic_grad_last_c = &b; + dynamic_grad_last_h->Resize(last_h_grad_unbind[current_layer_idx].dims()); + dynamic_grad_last_h->mutable_data(context.GetPlace()); + framework::TensorCopy(last_h_grad_unbind[current_layer_idx], + context.GetPlace(), dynamic_grad_last_h); + if (last_c_grad_unbind.size() > 0) { + dynamic_grad_last_c->Resize(last_c_grad_unbind[current_layer_idx].dims()); + dynamic_grad_last_c->mutable_data(context.GetPlace()); + framework::TensorCopy(last_c_grad_unbind[current_layer_idx], + context.GetPlace(), dynamic_grad_last_c); + } else { + dynamic_grad_last_c = nullptr; + } + + Tensor c, d; + Tensor* dynamic_grad_pre_h = &c; + Tensor* dynamic_grad_pre_c = &d; + math::SetConstant zero; + if (init_h_grad_unbind->size() > 0) { + dynamic_grad_pre_h->ShareDataWith( + (*init_h_grad_unbind)[current_layer_idx]); + } else { + dynamic_grad_pre_h->Resize(dynamic_grad_last_h->dims()); + dynamic_grad_pre_h->mutable_data(context.GetPlace()); + zero(device_ctx, dynamic_grad_pre_h, static_cast(0.0)); + } + if (init_c_grad_unbind->size() > 0) { + dynamic_grad_pre_c->ShareDataWith( + (*init_c_grad_unbind)[current_layer_idx]); + } else { + if (is_lstm(context) || is_gru(context)) { + dynamic_grad_pre_c->Resize(dynamic_grad_last_h->dims()); + dynamic_grad_pre_c->mutable_data(context.GetPlace()); + if (is_gru(context)) { + dynamic_grad_last_c = dynamic_grad_pre_c; + } + } else { + dynamic_grad_pre_c = nullptr; + } + } + + if (is_reverse) { + // must be reverse the input, output, input_grad, output_grad + // the gate and grad_gate must be reverse + std::reverse(layer_gate_tensor_unbind->begin(), + layer_gate_tensor_unbind->end()); + std::reverse(layer_grad_gate_tensor_unbind->begin(), + layer_grad_gate_tensor_unbind->end()); + /* + if (has_sequence_length) { + std::reverse(mask_tensor_list.begin(), mask_tensor_list.end()); + }*/ + std::reverse(output_tensor_unbind->begin(), output_tensor_unbind->end()); + std::reverse(output_grad_tensor_unbind->begin(), + output_grad_tensor_unbind->end()); + } + + Tensor* weight_grad = + &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 1]); + weight_grad->mutable_data(context.GetPlace()); + zero(device_ctx, weight_grad, static_cast(0.0)); + + Tensor* pre_hidden = nullptr; + Tensor* pre_state = nullptr; + Tensor* hidden = nullptr; + Tensor grad_gate_buf; + TensorList grad_gate_buf_unbind; + if (is_gru(context)) { + grad_gate_buf.Resize(layer_grad_gate_tensor->dims()); + grad_gate_buf.mutable_data(context.GetPlace()); + grad_gate_buf_unbind = Unbind(grad_gate_buf); + } + for (int i = time_step - 1; i >= 0; --i) { + if (has_sequence_length) { + this->mask_preprocess(context, &(*output_grad_tensor_unbind)[i], + dynamic_grad_last_h, dynamic_grad_last_c, + dynamic_grad_pre_h, dynamic_grad_pre_c, + mask_tensor_list[i]); + } else { + this->preprocess(context, &(*output_grad_tensor_unbind)[i], + dynamic_grad_last_h); + } + hidden = &(*output_tensor_unbind)[i]; + if (i 
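run_rnn_grad_function walks the sequence backwards: at each step the output gradient is added to the running hidden gradient, the cell backward turns it into a gradient for the previous hidden, and the two gradient holders are swapped instead of copied, mirroring the forward pass. A skeleton of that loop with a stand-in cell backward:

#include <utility>
#include <vector>

int main() {
  const int time_step = 4, size = 3;
  std::vector<float> buf_a(size, 0.f), buf_b(size, 0.f);
  std::vector<float>* d_last_h = &buf_a;  // gradient w.r.t. h_t
  std::vector<float>* d_pre_h = &buf_b;   // gradient w.r.t. h_{t-1}
  std::vector<std::vector<float>> d_output(time_step, std::vector<float>(size, 1.f));
  for (int i = time_step - 1; i >= 0; --i) {
    for (int j = 0; j < size; ++j) (*d_last_h)[j] += d_output[i][j];  // add dL/dy_i
    // stand-in for the cell backward: d_pre_h = f(d_last_h, saved forward values)
    for (int j = 0; j < size; ++j) (*d_pre_h)[j] = 0.5f * (*d_last_h)[j];
    std::swap(d_last_h, d_pre_h);  // this step's d(pre_hidden) feeds step i - 1
  }
  // *d_last_h now holds the gradient flowing into the initial hidden state.
}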
== 0) { + pre_hidden = &(*init_h_unbind)[current_layer_idx]; + if (init_c_unbind->size() > 0) { + pre_state = &(*init_c_unbind)[current_layer_idx]; + } + } else { + pre_hidden = &(*output_tensor_unbind)[i - 1]; + if (layer_state_tensor_unbind->size() > 0) { + pre_state = &(*layer_state_tensor_unbind)[begin_idx + i - 1]; + } + } + this->cell_( + context, &(*layer_gate_tensor_unbind)[i], + &(*layer_state_tensor_unbind)[begin_idx + i], + &(*layer_act_state_tensor_unbind)[begin_idx + i], hidden, + &(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]), + pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c, + &(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h, + dynamic_grad_pre_c, &grad_gate_buf_unbind[i], + &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]), + mask_tensor_list[i], has_sequence_length); + SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h); + SwapPoniter(&dynamic_grad_last_c, &dynamic_grad_pre_c); + } + // postproces for gradient for w_hi, X, bias_hi, bias_hh + this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad, + parameter_lists[layer_idx], + &((*weight_list_grad)[layer_idx]), &grad_gate_buf, + is_reverse); + + // copy the gradient to init_c init_h + if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) { + framework::TensorCopy(*dynamic_grad_last_h, context.GetPlace(), + &((*init_h_grad_unbind)[current_layer_idx])); + } + if ((*init_c_grad_unbind).size() > 0 && time_step % 2 == 0) { + framework::TensorCopy(*dynamic_grad_last_c, context.GetPlace(), + &((*init_c_grad_unbind)[current_layer_idx])); + } + } + + virtual void operator()( + const framework::ExecutionContext& context, const Tensor* input, + const Tensor* output, const TensorList& init_h_unbind, + const TensorList& init_c_unbind, const TensorList& last_h_grad_unbind, + const TensorList& last_c_grad_unbind, + const TensorList& gate_tensor_unbind, + const TensorList& state_tensor_unbind, + const TensorList& act_state_tensor_unbind, const Tensor* output_grad, + const std::vector& parameter_lists, + const Tensor* sequence_length, Tensor* input_grad, + TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind, + const std::vector& weight_list_grad, const int& layer_idx, + const int& gate_num) {} + void preprocess(const framework::ExecutionContext& context, + const Tensor* grad_output, Tensor* grad_last_h) { + auto& place = *context.template device_context() + .eigen_device(); + auto eigen_grad_output = framework::EigenMatrix::Reshape( + *grad_output, grad_output->dims().size() - 1); + auto eigen_grad_last_h = framework::EigenMatrix::Reshape( + *grad_last_h, grad_last_h->dims().size() - 1); + // the output gradient contribute the gradient to last_h + eigen_grad_last_h.device(place) = eigen_grad_last_h + eigen_grad_output; + } + + void mask_preprocess(const framework::ExecutionContext& context, + const Tensor* grad_output, Tensor* grad_last_h, + Tensor* grad_last_c, Tensor* grad_pre_h, + Tensor* grad_pre_c, const Tensor& mask_tensor) { + auto& place = *context.template device_context() + .eigen_device(); + auto eigen_mask = framework::EigenMatrix::From( + mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1})); + auto eigen_mask_broadcast = + eigen_mask.broadcast(Eigen::DSizes(1, grad_output->dims()[2])); + + auto eigen_grad_last_h = framework::EigenMatrix::Reshape( + *grad_last_h, grad_last_h->dims().size() - 1); + auto eigen_grad_pre_h = framework::EigenMatrix::Reshape( + *grad_pre_h, grad_pre_h->dims().size() - 1); + auto 
eigen_grad_output = framework::EigenMatrix::Reshape( + *grad_output, grad_output->dims().size() - 1); + eigen_grad_last_h.device(place) = + eigen_grad_last_h + eigen_grad_output * eigen_mask_broadcast; + eigen_grad_pre_h.device(place) = + (1 - eigen_mask_broadcast) * eigen_grad_last_h; + eigen_grad_last_h.device(place) = eigen_mask_broadcast * eigen_grad_last_h; + + if (grad_last_c && grad_pre_c && is_lstm(context)) { + auto eigen_grad_last_c = framework::EigenMatrix::Reshape( + *grad_last_c, grad_last_c->dims().size() - 1); + auto eigen_grad_pre_c = framework::EigenMatrix::Reshape( + *grad_pre_c, grad_pre_c->dims().size() - 1); + eigen_grad_pre_c.device(place) = + (1 - eigen_mask_broadcast) * eigen_grad_last_c; + eigen_grad_last_c.device(place) = + eigen_mask_broadcast * eigen_grad_last_c; + } + } + + void postprocess(const framework::ExecutionContext& context, + const Tensor& grad_gate, const Tensor& input, + Tensor* input_grad, const TensorList& parameters, + TensorList* grad_parameters, Tensor* grad_gate_buf, + const int& is_reverse) { + // we get the grad_gate step by step, and need to bradocast the grad to the + // grad_w_hi, grad_bias_hi, grad_bias_hh + int begin_idx = 0; + if (is_reverse) { + begin_idx = 4; + } + auto& device_ctx = + context.template device_context(); + auto blas = math::GetBlas(device_ctx); + + // calc the gradient for the w_hi + auto mat_dim_out_grad = + math::CreateMatrixDescriptor(grad_gate.dims(), 0, true); + auto mat_dim_input = math::CreateMatrixDescriptor(input.dims(), 0, false); + mat_dim_out_grad.width_ *= mat_dim_out_grad.batch_size_; + mat_dim_out_grad.batch_size_ = 0; + mat_dim_input.height_ *= mat_dim_input.batch_size_; + mat_dim_input.batch_size_ = 0; + blas.MatMul(grad_gate, mat_dim_out_grad, input, mat_dim_input, + static_cast(1.0), &((*grad_parameters)[begin_idx + 0]), + T(0)); + + // calc the gradient for the X + auto mat_dim_out_grad_new = + math::CreateMatrixDescriptor(grad_gate.dims(), 0, false); + mat_dim_out_grad_new.height_ *= mat_dim_out_grad_new.batch_size_; + mat_dim_out_grad_new.batch_size_ = 0; + auto mat_dim_parameter = + math::CreateMatrixDescriptor(parameters[0].dims(), 0, false); + blas.MatMul(grad_gate, mat_dim_out_grad_new, parameters[begin_idx + 0], + mat_dim_parameter, static_cast(1.0), input_grad, T(1)); + + // calc the gradient of Bias_hi, Bias_hh + math::ColwiseSum col_sum; + Tensor tmp_grad_gate; + tmp_grad_gate.ShareDataWith(grad_gate); + tmp_grad_gate.Resize( + {grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]}); + col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2])); + // Bias_hh + if (is_gru(context)) { + grad_gate_buf->Resize(tmp_grad_gate.dims()); + col_sum(device_ctx, *grad_gate_buf, &((*grad_parameters)[begin_idx + 3])); + } else { + col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3])); + } + } + GradCellType cell_; +}; + +template +struct SingleGradLayer : GradLayer { + // explicit SingleGradLayer(GradCellType& cell) : cell_(cell) {} + explicit SingleGradLayer(const GradCellType& cell) + : GradLayer(cell) {} + virtual ~SingleGradLayer() {} + void operator()( + const framework::ExecutionContext& context, const Tensor* input, + const Tensor* output, std::vector* init_h_unbind, + std::vector* init_c_unbind, const TensorList& last_h_grad_unbind, + const TensorList& last_c_grad_unbind, + const TensorList& gate_tensor_unbind, + const TensorList& state_tensor_unbind, + const TensorList& act_state_tensor_unbind, const Tensor* output_grad, + const std::vector& 
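postprocess above reduces the per-step gate gradients into parameter gradients: with the gate gradient flattened to [time_step * batch, G] and the input to [time_step * batch, I], dW_ih is dGate^T * X, dX accumulates dGate * W_ih, and the bias gradients are column sums of dGate (the GRU case sums the masked buffer instead for bias_hh). A naive, loop-based sketch of those three reductions with illustrative sizes:

#include <vector>

int main() {
  const int rows = 6, G = 2, I = 3;                    // rows = time_step * batch
  std::vector<float> d_gate(rows * G, 1.f), x(rows * I, 2.f), w_ih(G * I, 0.5f);
  std::vector<float> d_w_ih(G * I, 0.f), d_x(rows * I, 0.f), d_b(G, 0.f);

  for (int r = 0; r < rows; ++r)
    for (int g = 0; g < G; ++g) {
      d_b[g] += d_gate[r * G + g];                               // bias grad: column sum
      for (int i = 0; i < I; ++i) {
        d_w_ih[g * I + i] += d_gate[r * G + g] * x[r * I + i];   // dW_ih = dGate^T * X
        d_x[r * I + i] += d_gate[r * G + g] * w_ih[g * I + i];   // dX += dGate * W_ih
      }
    }
}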
parameter_lists, + const Tensor* sequence_length, Tensor* input_grad, + TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind, + std::vector* weight_list_grad, const int& layer_idx, + const int& gate_num) { + auto& device_ctx = + context.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + const bool& is_bidirec = context.Attr("is_bidirec"); + const int& time_step = input->dims()[0]; + const int& batch_size = input->dims()[1]; + const int& direction_num = is_bidirec ? 2 : 1; + const int& hidden_size = context.Attr("hidden_size"); + + // in this section, create the gate_state_grad for the postprocess calculate + // ubind the output, the output from [time_step, batch_size, hidden_size] + auto output_tensor_unbind = Unbind(*output); + auto output_grad_tensor_unbind = Unbind(*output_grad); + auto layer_gate_tensor = gate_tensor_unbind[layer_idx]; + layer_gate_tensor.Resize( + {time_step * direction_num, batch_size, hidden_size * gate_num}); + auto layer_gate_tensor_unbind = Unbind(layer_gate_tensor); + // the gate_tensor and the grad_gate_tensor must be unbind + Tensor layer_grad_gate_tensor; + layer_grad_gate_tensor.Resize(layer_gate_tensor.dims()); + layer_grad_gate_tensor.mutable_data(context.GetPlace()); + auto layer_grad_gate_tensor_unbind = Unbind(layer_grad_gate_tensor); + + Tensor layer_state_tensor; + TensorList layer_state_tensor_unbind; + if (state_tensor_unbind.size() > 0) { + layer_state_tensor = state_tensor_unbind[layer_idx]; + layer_state_tensor.Resize( + {time_step * direction_num, batch_size, hidden_size}); + layer_state_tensor_unbind = Unbind(layer_state_tensor); + } + + Tensor layer_act_state_tensor; + TensorList layer_act_state_tensor_unbind; + if (act_state_tensor_unbind.size() > 0) { + layer_act_state_tensor = act_state_tensor_unbind[layer_idx]; + layer_act_state_tensor.Resize( + {time_step * direction_num, batch_size, hidden_size}); + layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor); + } + const bool& has_sequence_length = sequence_length == nullptr ? 
false : true; + this->run_rnn_grad_function( + context, device_ctx, input, input_grad, sequence_length, init_h_unbind, + init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, + &layer_grad_gate_tensor, &layer_gate_tensor_unbind, + &layer_grad_gate_tensor_unbind, &layer_state_tensor_unbind, + &layer_act_state_tensor_unbind, &output_tensor_unbind, + &output_grad_tensor_unbind, last_h_grad_unbind, last_c_grad_unbind, + parameter_lists, weight_list_grad, layer_idx, time_step, + has_sequence_length, is_bidirec, false); + } +}; +template +void split_tensor_at_last_dim(const framework::ExecutionContext& context, + const platform::CPUDeviceContext& dev_ctx, + const Tensor* output, + std::vector* output_vec, + const int& axis) { + std::vector shape_refer; + (*output_vec)[0]->Resize( + {output->dims()[0], output->dims()[1], output->dims()[2] / 2}); + (*output_vec)[0]->mutable_data(context.GetPlace()); + (*output_vec)[1]->Resize( + {output->dims()[0], output->dims()[1], output->dims()[2] / 2}); + (*output_vec)[1]->mutable_data(context.GetPlace()); + shape_refer.emplace_back((*output_vec)[0]); + shape_refer.emplace_back((*output_vec)[1]); + math::SplitFunctor functor; + functor(dev_ctx, *output, shape_refer, axis, output_vec); +} + +template +struct BidirGradLayer : GradLayer { + explicit BidirGradLayer(const GradCellType& cell) + : GradLayer(cell) {} + virtual ~BidirGradLayer() {} + void operator()( + const framework::ExecutionContext& context, const Tensor* input, + const Tensor* output, std::vector* init_h_unbind, + std::vector* init_c_unbind, const TensorList& last_h_grad_unbind, + const TensorList& last_c_grad_unbind, + const TensorList& gate_tensor_unbind, + const TensorList& state_tensor_unbind, + const TensorList& act_state_tensor_unbind, const Tensor* output_grad, + const std::vector& parameter_lists, + const Tensor* sequence_length, Tensor* input_grad, + TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind, + std::vector* weight_list_grad, const int& layer_idx, + const int& gate_num) { + const bool& is_bidirec = context.Attr("is_bidirec"); + const int& time_step = input->dims()[0]; + const int& batch_size = input->dims()[1]; + const int& direction_num = is_bidirec ? 
2 : 1; + const int& hidden_size = context.Attr("hidden_size"); + // split the output two tensor to output_forward, output_backward + auto& device_ctx = + context.template device_context(); + math::SetConstant zero; + zero(device_ctx, input_grad, static_cast(0.0)); + + std::vector output_vec; + Tensor forward_output; + Tensor backward_output; + std::vector forward_output_tensor_unbind; + std::vector backward_output_tensor_unbind; + // in the last layer, we will use the output as the last hidden + // the output just the concat the forward hidden, backward hidden, so just + // split it + // in other layer, we just split the hidden in the rows + output_vec.emplace_back(&forward_output); + output_vec.emplace_back(&backward_output); + split_tensor_at_last_dim(context, device_ctx, output, &output_vec, 2); + forward_output_tensor_unbind = Unbind(*(output_vec[0])); + backward_output_tensor_unbind = Unbind(*(output_vec[1])); + + std::vector output_grad_vec; + Tensor grad_forward_output; + Tensor grad_backward_output; + output_grad_vec.emplace_back(&grad_forward_output); + output_grad_vec.emplace_back(&grad_backward_output); + split_tensor_at_last_dim(context, device_ctx, output_grad, + &output_grad_vec, 2); + auto forward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[0])); + auto backward_output_grad_tensor_unbind = Unbind(*(output_grad_vec[1])); + + // the gate_tensor and the grad_gate_tensor must be unbind + auto layer_gate_tensor = gate_tensor_unbind[layer_idx]; + layer_gate_tensor.Resize( + {time_step * 2, batch_size, hidden_size * gate_num}); + auto layer_forward_gate_tensor = layer_gate_tensor.Slice(0, time_step); + auto layer_backward_gate_tensor = + layer_gate_tensor.Slice(time_step, 2 * time_step); + auto layer_forward_gate_tensor_unbind = Unbind(layer_forward_gate_tensor); + auto layer_backward_gate_tensor_unbind = Unbind(layer_backward_gate_tensor); + + Tensor layer_grad_gate_tensor; + layer_grad_gate_tensor.Resize(layer_gate_tensor.dims()); + layer_grad_gate_tensor.mutable_data(context.GetPlace()); + zero(device_ctx, &layer_grad_gate_tensor, static_cast(0.0)); + auto layer_forward_grad_gate_tensor = + layer_grad_gate_tensor.Slice(0, time_step); + auto layer_backward_grad_gate_tensor = + layer_grad_gate_tensor.Slice(time_step, 2 * time_step); + auto layer_forward_grad_gate_tensor_unbind = + Unbind(layer_forward_grad_gate_tensor); + auto layer_backward_grad_gate_tensor_unbind = + Unbind(layer_backward_grad_gate_tensor); + + Tensor layer_state_tensor; + TensorList layer_state_tensor_unbind; + if (state_tensor_unbind.size() > 0) { + layer_state_tensor = state_tensor_unbind[layer_idx]; + layer_state_tensor.Resize( + {time_step * direction_num, batch_size, hidden_size}); + layer_state_tensor_unbind = Unbind(layer_state_tensor); + } + + Tensor layer_act_state_tensor; + TensorList layer_act_state_tensor_unbind; + if (act_state_tensor_unbind.size() > 0) { + layer_act_state_tensor = act_state_tensor_unbind[layer_idx]; + layer_act_state_tensor.Resize( + {time_step * direction_num, batch_size, hidden_size}); + layer_act_state_tensor_unbind = Unbind(layer_act_state_tensor); + } + const bool& has_sequence_length = sequence_length == nullptr ? 
false : true; + + this->run_rnn_grad_function( + context, device_ctx, input, input_grad, sequence_length, init_h_unbind, + init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, + &layer_forward_grad_gate_tensor, &layer_forward_gate_tensor_unbind, + &layer_forward_grad_gate_tensor_unbind, &layer_state_tensor_unbind, + &layer_act_state_tensor_unbind, &forward_output_tensor_unbind, + &forward_output_grad_tensor_unbind, last_h_grad_unbind, + last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx, + time_step, has_sequence_length, is_bidirec, false); + + this->run_rnn_grad_function( + context, device_ctx, input, input_grad, sequence_length, init_h_unbind, + init_c_unbind, init_h_grad_unbind, init_c_grad_unbind, + &layer_backward_grad_gate_tensor, &layer_backward_gate_tensor_unbind, + &layer_backward_grad_gate_tensor_unbind, &layer_state_tensor_unbind, + &layer_act_state_tensor_unbind, &backward_output_tensor_unbind, + &backward_output_grad_tensor_unbind, last_h_grad_unbind, + last_c_grad_unbind, parameter_lists, weight_list_grad, layer_idx, + time_step, has_sequence_length, is_bidirec, true); + } +}; + +template +void backup_tensor(const framework::ExecutionContext& context, Tensor* dst, + Tensor* src) { + auto& device_ctx = + context.template device_context(); + dst->Resize(src->dims()); + dst->mutable_data(context.GetPlace()); + framework::TensorCopy(*src, device_ctx.GetPlace(), device_ctx, dst); +} + +template +struct GradCell { + virtual ~GradCell() {} + virtual void operator()(const framework::ExecutionContext& context, + Tensor* gate_tensor, Tensor* state_tensor, + Tensor* act_state_tensor, Tensor* hidden_tensor, + const Tensor* weight_hh, Tensor* pre_hidden, + Tensor* pre_state, Tensor* grad_hidden, + Tensor* grad_state, Tensor* grad_gate, + Tensor* grad_weight_hh, Tensor* grad_pre_hidden, + Tensor* grad_pre_state, Tensor* grad_gate_buf, + Tensor* grad_bias_hh, const Tensor& mask_tensor, + bool has_sequence_length) const {} + virtual void update_pre_hidden_grad( + const framework::ExecutionContext& context, Tensor* grad_gate, + const Tensor* weight_hh, Tensor* grad_pre_hidden, + Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state, + Tensor* grad_pre_state_bak, Tensor* grad_gate_buf, + const Tensor& mask_tensor, bool has_sequence_length) const { + auto& device_ctx = + context.template device_context(); + auto blas = math::GetBlas(device_ctx); + T beta = 0; + Tensor* grad_gate_tmp = grad_gate; + if (is_gru(context)) { + beta = 1.0; + grad_gate_tmp = grad_gate_buf; + } + + auto mat_dim_a = + math::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false); + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, false); + blas.MatMul(*grad_gate_tmp, mat_dim_a, *weight_hh, mat_dim_b, + static_cast(1.0), grad_pre_hidden, beta); + + if (has_sequence_length) { + auto& place = + *context.template device_context() + .eigen_device(); + auto eigen_mask = framework::EigenMatrix::From( + mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1})); + auto eigen_mask_broadcast = eigen_mask.broadcast( + Eigen::DSizes(1, grad_pre_hidden->dims()[2])); + auto eigen_grad_pre_hidden = framework::EigenMatrix::Reshape( + *grad_pre_hidden, grad_pre_hidden->dims().size() - 1); + auto eigen_grad_pre_hidden_bak = framework::EigenMatrix::Reshape( + *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1); + eigen_grad_pre_hidden.device(place) = + (1 - eigen_mask_broadcast) * eigen_grad_pre_hidden_bak + + 
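update_pre_hidden_grad propagates the gate gradients back to the previous hidden state, essentially d_pre_h = dGate * W_hh, and on padded steps restores the gradient saved before the step (the *_bak tensors) with the same mask blend used in the forward pass. A per-element sketch with toy sizes:

#include <vector>

int main() {
  const int batch = 1, G = 2, H = 2;                   // gate rows, hidden size
  std::vector<float> d_gate(batch * G, 1.f);           // gate gradients for this step
  std::vector<float> w_hh(G * H, 0.5f);                // hidden-to-hidden weights
  std::vector<float> d_pre_h(batch * H, 0.f), d_pre_h_bak(batch * H, 7.f);
  std::vector<float> mask{0.f};                        // 0: this step is padding

  for (int b = 0; b < batch; ++b)
    for (int h = 0; h < H; ++h) {
      float acc = 0.f;
      for (int g = 0; g < G; ++g) acc += d_gate[b * G + g] * w_hh[g * H + h];
      float m = mask[b];
      // padded step: keep the gradient saved before this step was processed
      d_pre_h[b * H + h] = (1.f - m) * d_pre_h_bak[b * H + h] + m * acc;
    }
}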
eigen_grad_pre_hidden * eigen_mask_broadcast; + if (grad_pre_state) { + auto eigen_grad_pre_state = framework::EigenMatrix::Reshape( + *grad_pre_state, grad_pre_state->dims().size() - 1); + auto eigen_grad_pre_state_bak = framework::EigenMatrix::Reshape( + *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1); + eigen_grad_pre_state.device(place) = + (1 - eigen_mask_broadcast) * eigen_grad_pre_state_bak + + eigen_grad_pre_state * eigen_mask_broadcast; + } + } + } + + virtual void update_weight_hh_grad(const framework::ExecutionContext& context, + Tensor* grad_gate, Tensor* pre_hidden, + Tensor* grad_weight_hh, + Tensor* grad_gate_buf) const { + auto& device_ctx = + context.template device_context(); + auto blas = math::GetBlas(device_ctx); + auto mat_dim_c = math::CreateMatrixDescriptor(grad_gate->dims(), 0, true); + mat_dim_c.height_ *= mat_dim_c.batch_size_; + mat_dim_c.batch_size_ = 0; + auto mat_dim_d = math::CreateMatrixDescriptor(pre_hidden->dims(), 0, false); + mat_dim_d.height_ *= mat_dim_d.batch_size_; + mat_dim_d.batch_size_ = 0; + Tensor* grad_gate_tmp = grad_gate; + if (is_gru(context)) { + grad_gate_tmp = grad_gate_buf; + } + blas.MatMul(*grad_gate_tmp, mat_dim_c, *pre_hidden, mat_dim_d, + static_cast(1.0), grad_weight_hh, static_cast(1.0)); + } +}; + +template class EigenActivationBackwardFunctor> +struct SimpleRNNGradCell : GradCell { + void operator()(const framework::ExecutionContext& context, + Tensor* gate_tensor, Tensor* state_tensor, + Tensor* act_state_tensor, Tensor* hidden_tensor, + const Tensor* weight_hh, Tensor* pre_hidden, + Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, + Tensor* grad_gate, Tensor* grad_weight_hh, + Tensor* grad_pre_hidden, Tensor* grad_pre_state, + Tensor* grad_gate_buf, Tensor* grad_bias_hh, + const Tensor& mask_tensor, + bool has_sequence_length) const override { + auto& device_ctx = + context.template device_context(); + Tensor grad_pre_hidden_bak; + if (has_sequence_length) { + backup_tensor(context, &grad_pre_hidden_bak, grad_pre_hidden); + } + // h = act(z) + // update dz + auto dz = EigenVector::Flatten( + GET_DATA_SAFELY(grad_gate, "Output", "dz", "Grad")); + auto dh = EigenVector::Flatten( + GET_DATA_SAFELY(grad_hidden, "Input", "dh", "Grad")); + auto h = EigenVector::Flatten( + GET_DATA_SAFELY(hidden_tensor, "Input", "h", "Value")); + // useless, but need this argument to execute functor + auto z = EigenVector::Flatten( + GET_DATA_SAFELY(gate_tensor, "Input", "z", "Value")); + + auto* place = device_ctx.eigen_device(); + EigenActivationBackwardFunctor functor; + functor(*place, z, h, dh, dz); + + // update grad_weight_hh, grad_pre_hidden + this->update_pre_hidden_grad( + context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak, + nullptr, nullptr, grad_gate_buf, mask_tensor, has_sequence_length); + this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh, + grad_gate_buf); + } +}; + +template +struct GRUGradCell : GradCell { + void operator()(const framework::ExecutionContext& context, + Tensor* gate_tensor, Tensor* state_tensor, + Tensor* act_state_tensor, Tensor* hidden_tensor, + const Tensor* weight_hh, Tensor* pre_hidden, + Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, + Tensor* grad_gate, Tensor* grad_weight_hh, + Tensor* grad_pre_hidden, Tensor* grad_pre_state, + Tensor* grad_gate_buf, Tensor* grad_bias_hh, + const Tensor& mask_tensor, + bool has_sequence_length) const override { + auto& device_ctx = + context.template device_context(); + size_t frame_size = 
pre_hidden->dims()[2]; + size_t batch_size = pre_hidden->dims()[1]; + Tensor grad_pre_hidden_bak; + if (has_sequence_length) { + backup_tensor(context, &grad_pre_hidden_bak, grad_pre_hidden); + } + // zero pre_hidden + math::SetConstant zero; + zero(device_ctx, grad_pre_hidden, static_cast(0.0)); + math::GRUMetaValue gru_value; + math::GRUMetaGrad gru_grad; + gru_value.gate_value = gate_tensor->data(); + gru_value.prev_out_value = pre_hidden->data(); + gru_value.reset_output_value = state_tensor->data(); + + gru_grad.gate_grad = grad_gate->data(); + gru_grad.reset_output_grad = grad_state->data(); + gru_grad.prev_out_grad = grad_pre_hidden->data(); + gru_grad.output_grad = grad_hidden->data(); + gru_grad.gate_weight_grad = grad_weight_hh->data(); + gru_grad.state_weight_grad = + grad_weight_hh->data() + 2 * frame_size * frame_size; + gru_grad.state_bias_grad = grad_bias_hh->data() + 2 * frame_size; + + auto act_gate = math::detail::GetActivationType("sigmoid_v2"); + auto act_node = math::detail::GetActivationType("tanh_v2"); + math::GRUUnitGradFunctorV2::compute( + device_ctx, gru_value, gru_grad, frame_size, batch_size, act_node, + act_gate); + + make_grad_gate_buf(context, grad_gate, grad_gate_buf, grad_state); + + this->update_pre_hidden_grad( + context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak, + nullptr, nullptr, grad_gate_buf, mask_tensor, has_sequence_length); + this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh, + grad_gate_buf); + } +}; + +template +struct LSTMGradCell : GradCell { + void operator()(const framework::ExecutionContext& context, + Tensor* gate_tensor, Tensor* state_tensor, + Tensor* act_state_tensor, Tensor* hidden_tensor, + const Tensor* weight_hh, Tensor* pre_hidden, + Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state, + Tensor* grad_gate, Tensor* grad_weight_hh, + Tensor* grad_pre_hidden, Tensor* grad_pre_state, + Tensor* grad_gate_buf, Tensor* grad_bias_hh, + const Tensor& mask_tensor, + bool has_sequence_length) const override { + auto& device_ctx = + context.template device_context(); + size_t frame_size = state_tensor->dims()[2]; + size_t batch_size = state_tensor->dims()[1]; + + Tensor grad_pre_hidden_bak; + Tensor grad_pre_state_bak; + if (has_sequence_length) { + backup_tensor(context, &grad_pre_hidden_bak, grad_pre_hidden); + backup_tensor(context, &grad_pre_state_bak, grad_pre_state); + } + + math::LstmMetaValue lstm_value; + math::LstmMetaGrad lstm_grad; + create_lstm_value(&lstm_value); + create_lstm_grad(&lstm_grad); + lstm_value.gate_value = gate_tensor->data(); + lstm_value.state_value = state_tensor->data(); + lstm_value.state_active_value = act_state_tensor->data(); + lstm_value.prev_state_value = pre_state->data(); + + lstm_grad.state_grad = grad_state->data(); + lstm_grad.gate_grad = grad_gate->data(); + lstm_grad.output_grad = grad_hidden->data(); + lstm_grad.prev_state_grad = grad_pre_state->data(); + + lstm_value.output_value = nullptr; + lstm_grad.state_active_grad = nullptr; + + auto gate_act = math::detail::GetActivationType("sigmoid_v2"); + auto state_act = math::detail::GetActivationType("tanh_v2"); + auto cand_act = math::detail::GetActivationType("tanh_v2"); + + T cell_clip = 0.0; + math::LstmUnitGradFunctor::compute( + device_ctx, lstm_value, lstm_grad, frame_size, batch_size, cell_clip, + gate_act, state_act, cand_act, false); + this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden, + &grad_pre_hidden_bak, grad_pre_state, + &grad_pre_state_bak, grad_gate_buf, + 
mask_tensor, has_sequence_length); + this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh, + grad_gate_buf); + } +}; + +template class SingleGradLayerT, + template class BidirGradLayerT, typename T> +void RnnGradFunc(const framework::ExecutionContext& context, + const int& gate_num) { + // get the tensor pointer for the input + auto* input = context.Input("Input"); + auto weight_list = context.MultiInput("WeightList"); + auto pre_state = context.MultiInput("PreState"); + + const Tensor* init_h = pre_state[0]; + const Tensor* init_c = nullptr; + if (is_lstm(context)) { + init_c = pre_state[1]; + } + auto* reserve_state = context.Input("Reserve"); + auto* dropout_state = context.Input("DropoutState"); + auto* output = context.Input("Out"); + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto state_grad = context.MultiInput(framework::GradVarName("State")); + const Tensor* last_h_grad = state_grad[0]; + const Tensor* last_c_grad = nullptr; + if (is_lstm(context)) { + last_c_grad = state_grad[1]; + } + + bool has_seq_length = context.HasInput("SequenceLength"); + const Tensor* sequence_length = nullptr; + if (has_seq_length) { + sequence_length = context.Input("SequenceLength"); + } + + // get the tensor pointer for the output + auto* input_grad = context.Output(framework::GradVarName("Input")); + auto weight_grad_list = context.MultiOutput( + framework::GradVarName("WeightList")); + auto pre_state_grad = + context.MultiOutput(framework::GradVarName("PreState")); + Tensor* init_h_grad = nullptr; + Tensor* init_c_grad = nullptr; + if (pre_state_grad.size() > 0) { // has gradient + init_h_grad = pre_state_grad[0]; + if (is_lstm(context)) { + init_c_grad = pre_state_grad[1]; + } + } + + // get the attributes for the calculation + const int& num_layers = context.Attr("num_layers"); + const bool& is_bidirec = context.Attr("is_bidirec"); + const float& dropout_prob = context.Attr("dropout_prob"); + const bool& is_test = context.Attr("is_test"); + + // get the input_size, batch_size, time_step, hidden_size + const int& time_step = input->dims()[0]; + const int& batch_size = input->dims()[1]; + const int& hidden_size = context.Attr("hidden_size"); + const int& direction_num = is_bidirec ?
2 : 1; + // allocate the memory and initialize the input_grad + Tensor input_grad_value; + if (!input_grad) { + input_grad = &input_grad_value; + } + input_grad->mutable_data(input->dims(), context.GetPlace()); + + if (init_h_grad) { + init_h_grad->mutable_data(init_h->dims(), context.GetPlace()); + } + if (init_c_grad) { + init_c_grad->mutable_data(init_c->dims(), context.GetPlace()); + } + + // reset the parameters to sorted order and allocate the memory + std::vector parameter_lists; + parameter_lists.reserve(num_layers); + reset_parameter_vector(weight_list, num_layers, gate_num, is_bidirec, + &parameter_lists); + + for (unsigned int i = 0; i < weight_grad_list.size(); ++i) { + weight_grad_list[i]->mutable_data(context.GetPlace()); + } + std::vector parameter_lists_grad; + parameter_lists_grad.reserve(num_layers); + reset_parameter_vector(weight_grad_list, num_layers, gate_num, is_bidirec, + &parameter_lists_grad); + + // split the saved intermediate state out of reserve_state + Tensor gate_tensor; + Tensor state_tensor; + Tensor act_state_tensor; + Tensor hidden_tensor; + SplitReserveData(context, reserve_state, &gate_tensor, &state_tensor, + &act_state_tensor, &hidden_tensor, direction_num, time_step, + batch_size, hidden_size, gate_num, num_layers); + int gate_num_tmp = gate_num; + if (gate_num == 0) { + gate_num_tmp = 1; + } + gate_tensor.Resize({num_layers, time_step * direction_num, batch_size, + hidden_size * gate_num_tmp}); + if (state_tensor.numel() > 0) { + state_tensor.Resize( + {num_layers, time_step * direction_num, batch_size, hidden_size}); + } + if (act_state_tensor.numel() > 0) { + act_state_tensor.Resize( + {num_layers, time_step * direction_num, batch_size, hidden_size}); + } + if (num_layers > 1) { + hidden_tensor.Resize( + {num_layers - 1, time_step, batch_size, hidden_size * direction_num}); + } + // unbind + auto last_h_grad_unbind = Unbind(*last_h_grad); + auto gate_tensor_unbind = Unbind(gate_tensor); + TensorList last_c_grad_unbind; + if (last_c_grad) { + last_c_grad_unbind = Unbind(*last_c_grad); + } + + TensorList init_h_unbind, init_c_unbind; + TensorList init_h_grad_unbind, init_c_grad_unbind; + TensorList state_tensor_unbind, act_state_tensor_unbind; + TensorList hidden_tensor_unbind; + + init_h_unbind = Unbind(*init_h); + if (init_c) { + init_c_unbind = Unbind(*init_c); + } + + if (init_h_grad != nullptr) { + init_h_grad_unbind = Unbind(*init_h_grad); + } + if (init_c_grad != nullptr) { + init_c_grad_unbind = Unbind(*init_c_grad); + } + if (state_tensor.numel() > 0) { + state_tensor_unbind = Unbind(state_tensor); + } + if (act_state_tensor.numel() > 0) { + act_state_tensor_unbind = Unbind(act_state_tensor); + } + if (num_layers > 1) { + hidden_tensor_unbind = Unbind(hidden_tensor); + } + // squeeze the first dim of the hidden tensors + for (unsigned int i = 0; i < hidden_tensor_unbind.size(); i++) { + hidden_tensor_unbind[i].Resize( + framework::slice_ddim(hidden_tensor_unbind[i].dims(), 1, + hidden_tensor_unbind[i].dims().size())); + } + // add the output tensor to the hidden vector + Tensor tmp; + hidden_tensor_unbind.emplace_back(tmp); + hidden_tensor_unbind[num_layers - 1].ShareDataWith(*output); + + GradCellType cell; + Tensor layer_input; + Tensor layer_output; + Tensor* layer_input_grad_holder = nullptr; + Tensor tmp_out; + tmp_out.ShareDataWith(*output_grad); + Tensor* layer_output_grad_holder = &tmp_out; + Tensor input_grad_temp; + Tensor output_grad_temp; + + bool has_allocate_mem = false; + for (int i = num_layers - 1; i >= 0; --i) { + // the layer input and output were saved in the forward pass, just reuse the
data + if (i > 0) { + layer_input.ShareDataWith(hidden_tensor_unbind[i - 1]); + } else { + layer_input.ShareDataWith(*input); + } + layer_output.ShareDataWith(hidden_tensor_unbind[i]); + if (num_layers == 1) { + layer_input_grad_holder = input_grad; + } else { + if (i == num_layers - 1) { + input_grad_temp.Resize(layer_input.dims()); + input_grad_temp.mutable_data(context.GetPlace()); + layer_input_grad_holder = &input_grad_temp; + } + } + if (is_bidirec) { + BidirGradLayerT layer(cell); + layer(context, &layer_input, &layer_output, &init_h_unbind, + &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind, + gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind, + layer_output_grad_holder, parameter_lists, sequence_length, + layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind, + &parameter_lists_grad, i, gate_num_tmp); + } else { + SingleGradLayerT layer(cell); + layer(context, &layer_input, &layer_output, &init_h_unbind, + &init_c_unbind, last_h_grad_unbind, last_c_grad_unbind, + gate_tensor_unbind, state_tensor_unbind, act_state_tensor_unbind, + layer_output_grad_holder, parameter_lists, sequence_length, + layer_input_grad_holder, &init_h_grad_unbind, &init_c_grad_unbind, + &parameter_lists_grad, i, gate_num_tmp); + } + + // calculate the dropout gradient for the layer_input_grad_holder + // dropout_state was saved in the forward pass + if (i > 0) { + if ((!is_test) && (dropout_prob != 0)) { + dropout_cpu_grad_function_inplace(context, layer_input_grad_holder, + dropout_state, dropout_prob); + } + } + + if (i - 1 == 0) { + layer_output_grad_holder = input_grad; + } else { + if (!has_allocate_mem) { + output_grad_temp.Resize(layer_input_grad_holder->dims()); + output_grad_temp.mutable_data(context.GetPlace()); + layer_output_grad_holder = &output_grad_temp; + has_allocate_mem = true; + } + } + SwapPoniter(&layer_input_grad_holder, &layer_output_grad_holder); + } +} + +template +class RNNCPUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int gate_num = 4; + if (is_lstm(ctx)) { + RnnGradFunc, SingleGradLayer, BidirGradLayer, T>( + ctx, gate_num); + } else if (is_gru(ctx)) { + gate_num = 3; + RnnGradFunc, SingleGradLayer, BidirGradLayer, T>(ctx, + gate_num); + // run gru + } else if (is_rnn_relu(ctx)) { + gate_num = 1; + RnnGradFunc, SingleGradLayer, + BidirGradLayer, T>(ctx, gate_num); + // run rnn + } else if (is_rnn_tanh(ctx)) { + gate_num = 1; + RnnGradFunc, SingleGradLayer, + BidirGradLayer, T>(ctx, gate_num); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py index cfb4bb69a2..cab858f048 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py @@ -65,10 +65,18 @@ class TestLstm(unittest.TestCase): paddle.jit.ProgramTranslator().enable(True) net = Net(12, 2) x = paddle.randn((2, 10, 12)) + x.stop_gradient = False dygraph_out = net(x) + loss = paddle.mean(dygraph_out) + sgd = paddle.optimizer.SGD(learning_rate=0.001, + parameters=net.parameters()) + loss.backward() + sgd.step() # switch eval mode firstly net.eval() - + x = paddle.randn((2, 10, 12)) + dygraph_out = net(x) + dropout_out = net(x) net = paddle.jit.to_static( net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])]) paddle.jit.save(net, 'simple_lstm') @@ -106,6 +114,14 @@ class
TestSaveInEvalMode(unittest.TestCase): def test_save_in_eval(self): paddle.jit.ProgramTranslator().enable(True) net = LinearNet() + x = paddle.randn((2, 10)) + x.stop_gradient = False + dygraph_out = net(x) + loss = paddle.mean(dygraph_out) + sgd = paddle.optimizer.SGD(learning_rate=0.001, + parameters=net.parameters()) + loss.backward() + sgd.step() # switch eval mode firstly net.eval() # save directly @@ -129,6 +145,14 @@ class TestEvalAfterSave(unittest.TestCase): def test_eval_after_save(self): x = paddle.randn((2, 10, 12)).astype('float32') net = Net(12, 2) + x.stop_gradient = False + dy_out = net(x) + loss = paddle.mean(dy_out) + sgd = paddle.optimizer.SGD(learning_rate=0.001, + parameters=net.parameters()) + loss.backward() + sgd.step() + x = paddle.randn((2, 10, 12)).astype('float32') dy_out = net(x) # save model paddle.jit.save(net, 'jit.save/lstm', input_spec=[x]) diff --git a/python/paddle/fluid/tests/unittests/rnn/convert.py b/python/paddle/fluid/tests/unittests/rnn/convert.py index 02f10694a4..645f67fca2 100644 --- a/python/paddle/fluid/tests/unittests/rnn/convert.py +++ b/python/paddle/fluid/tests/unittests/rnn/convert.py @@ -49,3 +49,34 @@ def convert_params_for_net_static(np_net, paddle_net, place): paddle_layer.cell_fw, place) convert_params_for_cell_static(np_layer.cell_bw, paddle_layer.cell_bw, place) + + +def get_params_for_cell(np_cell, num_layers, idx): + state = np_cell.parameters + weight_list = [ + ('{}.weight_{}'.format(num_layers, idx), state['weight_ih']), + ('{}.weight_{}'.format(num_layers, idx + 1), state['weight_hh']) + ] + bias_list = [('{}.bias_{}'.format(num_layers, idx), state['bias_ih']), + ('{}.bias_{}'.format(num_layers, idx + 1), state['bias_hh'])] + return weight_list, bias_list + + +def get_params_for_net(np_net): + weight_list = [] + bias_list = [] + for layer_idx, np_layer in enumerate(np_net): + if hasattr(np_layer, "cell"): + weight, bias = get_params_for_cell(np_layer.cell, layer_idx, 0) + for w, b in zip(weight, bias): + weight_list.append(w) + bias_list.append(b) + else: + for count, cell in enumerate([np_layer.cell_fw, np_layer.cell_bw]): + weight, bias = get_params_for_cell(cell, layer_idx, count * 2) + for w, b in zip(weight, bias): + weight_list.append(w) + bias_list.append(b) + + weight_list.extend(bias_list) + return weight_list diff --git a/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py index 317be28da4..d9149b0628 100644 --- a/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py +++ b/python/paddle/fluid/tests/unittests/rnn/rnn_numpy.py @@ -33,11 +33,16 @@ class LayerListMixin(LayerMixin): class SimpleRNNCell(LayerMixin): - def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"): + def __init__(self, + input_size, + hidden_size, + bias=True, + nonlinearity="RNN_TANH", + dtype="float64"): self.input_size = input_size self.hidden_size = hidden_size self.bias = bias - if nonlinearity == 'tanh': + if nonlinearity == 'RNN_TANH': self.nonlinearity = np.tanh else: self.nonlinearity = lambda x: np.maximum(x, 0.) 
@@ -45,16 +50,16 @@ class SimpleRNNCell(LayerMixin): self.parameters = dict() std = 1.0 / math.sqrt(hidden_size) self.weight_ih = np.random.uniform(-std, std, ( - hidden_size, input_size)).astype('float64') + hidden_size, input_size)).astype(dtype) self.weight_hh = np.random.uniform(-std, std, ( - hidden_size, hidden_size)).astype('float64') + hidden_size, hidden_size)).astype(dtype) self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh if bias: self.bias_ih = np.random.uniform(-std, std, - (hidden_size, )).astype('float64') + (hidden_size, )).astype(dtype) self.bias_hh = np.random.uniform(-std, std, - (hidden_size, )).astype('float64') + (hidden_size, )).astype(dtype) self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_hh'] = self.bias_hh else: @@ -80,23 +85,23 @@ class SimpleRNNCell(LayerMixin): class GRUCell(LayerMixin): - def __init__(self, input_size, hidden_size, bias=True): + def __init__(self, input_size, hidden_size, bias=True, dtype="float64"): self.input_size = input_size self.hidden_size = hidden_size self.bias = bias self.parameters = dict() std = 1.0 / math.sqrt(hidden_size) self.weight_ih = np.random.uniform(-std, std, ( - 3 * hidden_size, input_size)).astype('float64') + 3 * hidden_size, input_size)).astype(dtype) self.weight_hh = np.random.uniform(-std, std, ( - 3 * hidden_size, hidden_size)).astype('float64') + 3 * hidden_size, hidden_size)).astype(dtype) self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh if bias: - self.bias_ih = np.random.uniform(-std, std, ( - 3 * hidden_size)).astype('float64') - self.bias_hh = np.random.uniform(-std, std, ( - 3 * hidden_size)).astype('float64') + self.bias_ih = np.random.uniform(-std, std, + (3 * hidden_size)).astype(dtype) + self.bias_hh = np.random.uniform(-std, std, + (3 * hidden_size)).astype(dtype) self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_hh'] = self.bias_hh else: @@ -128,23 +133,23 @@ class GRUCell(LayerMixin): class LSTMCell(LayerMixin): - def __init__(self, input_size, hidden_size, bias=True): + def __init__(self, input_size, hidden_size, bias=True, dtype="float64"): self.input_size = input_size self.hidden_size = hidden_size self.bias = bias self.parameters = dict() std = 1.0 / math.sqrt(hidden_size) self.weight_ih = np.random.uniform(-std, std, ( - 4 * hidden_size, input_size)).astype('float64') + 4 * hidden_size, input_size)).astype(dtype) self.weight_hh = np.random.uniform(-std, std, ( - 4 * hidden_size, hidden_size)).astype('float64') + 4 * hidden_size, hidden_size)).astype(dtype) self.parameters['weight_ih'] = self.weight_ih self.parameters['weight_hh'] = self.weight_hh if bias: - self.bias_ih = np.random.uniform(-std, std, ( - 4 * hidden_size)).astype('float64') - self.bias_hh = np.random.uniform(-std, std, ( - 4 * hidden_size)).astype('float64') + self.bias_ih = np.random.uniform(-std, std, + (4 * hidden_size)).astype(dtype) + self.bias_hh = np.random.uniform(-std, std, + (4 * hidden_size)).astype(dtype) self.parameters['bias_ih'] = self.bias_ih self.parameters['bias_hh'] = self.bias_hh else: @@ -403,28 +408,36 @@ class SimpleRNN(RNNMixin): input_size, hidden_size, num_layers=1, - nonlinearity="tanh", + nonlinearity="RNN_TANH", direction="forward", dropout=0., - time_major=False): + time_major=False, + dtype="float64"): super(SimpleRNN, self).__init__() if direction in ["forward", "backward"]: is_reverse = direction == "backward" - cell = SimpleRNNCell(input_size, hidden_size, nonlinearity) + cell = 
SimpleRNNCell( + input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) for i in range(1, num_layers): - cell = SimpleRNNCell(hidden_size, hidden_size, nonlinearity) + cell = SimpleRNNCell( + hidden_size, + hidden_size, + nonlinearity=nonlinearity, + dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) elif direction == "bidirectional": - cell_fw = SimpleRNNCell(input_size, hidden_size, nonlinearity) - cell_bw = SimpleRNNCell(input_size, hidden_size, nonlinearity) + cell_fw = SimpleRNNCell( + input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype) + cell_bw = SimpleRNNCell( + input_size, hidden_size, nonlinearity=nonlinearity, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) for i in range(1, num_layers): - cell_fw = SimpleRNNCell(2 * hidden_size, hidden_size, - nonlinearity) - cell_bw = SimpleRNNCell(2 * hidden_size, hidden_size, - nonlinearity) + cell_fw = SimpleRNNCell( + 2 * hidden_size, hidden_size, nonlinearity, dtype=dtype) + cell_bw = SimpleRNNCell( + 2 * hidden_size, hidden_size, nonlinearity, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) else: raise ValueError( @@ -447,23 +460,24 @@ class LSTM(RNNMixin): num_layers=1, direction="forward", dropout=0., - time_major=False): + time_major=False, + dtype="float64"): super(LSTM, self).__init__() if direction in ["forward", "backward"]: is_reverse = direction == "backward" - cell = LSTMCell(input_size, hidden_size) + cell = LSTMCell(input_size, hidden_size, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) for i in range(1, num_layers): - cell = LSTMCell(hidden_size, hidden_size) + cell = LSTMCell(hidden_size, hidden_size, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) elif direction == "bidirectional": - cell_fw = LSTMCell(input_size, hidden_size) - cell_bw = LSTMCell(input_size, hidden_size) + cell_fw = LSTMCell(input_size, hidden_size, dtype=dtype) + cell_bw = LSTMCell(input_size, hidden_size, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) for i in range(1, num_layers): - cell_fw = LSTMCell(2 * hidden_size, hidden_size) - cell_bw = LSTMCell(2 * hidden_size, hidden_size) + cell_fw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype) + cell_bw = LSTMCell(2 * hidden_size, hidden_size, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) else: raise ValueError( @@ -486,23 +500,24 @@ class GRU(RNNMixin): num_layers=1, direction="forward", dropout=0., - time_major=False): + time_major=False, + dtype="float64"): super(GRU, self).__init__() if direction in ["forward", "backward"]: is_reverse = direction == "backward" - cell = GRUCell(input_size, hidden_size) + cell = GRUCell(input_size, hidden_size, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) for i in range(1, num_layers): - cell = GRUCell(hidden_size, hidden_size) + cell = GRUCell(hidden_size, hidden_size, dtype=dtype) self.append(RNN(cell, is_reverse, time_major)) elif direction == "bidirectional": - cell_fw = GRUCell(input_size, hidden_size) - cell_bw = GRUCell(input_size, hidden_size) + cell_fw = GRUCell(input_size, hidden_size, dtype=dtype) + cell_bw = GRUCell(input_size, hidden_size, dtype=dtype) self.append(BiRNN(cell_fw, cell_bw, time_major)) for i in range(1, num_layers): - cell_fw = GRUCell(2 * hidden_size, hidden_size) - cell_bw = GRUCell(2 * hidden_size, hidden_size) + cell_fw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype) + cell_bw = GRUCell(2 * hidden_size, hidden_size, dtype=dtype) self.append(BiRNN(cell_fw, 
cell_bw, time_major)) else: raise ValueError( diff --git a/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py b/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py new file mode 100644 index 0000000000..eb1fed81cb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gru_rnn_op.py @@ -0,0 +1,164 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import math + +from op_test import OpTest +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import random +import sys +sys.path.append("./rnn") +from rnn_numpy import GRU +from convert import get_params_for_net +random.seed(2) +np.set_printoptions(threshold=np.inf) +paddle.enable_static() + + +class TestGRUOp(OpTest): + def get_weight_names(self): + weight_names = [] + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.weight_{}".format(i, j)) + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.bias_{}".format(i, j)) + return weight_names + + def setUp(self): + self.op_type = "rnn" + self.dtype = "float64" + self.sequence_length = np.array( + [12, 11, 10, 9, 8, 7, 6, 5], dtype=np.int32) + self.num_layers = 1 + self.is_bidirec = False + self.is_test = False + self.mode = "GRU" + self.dropout = 0. 
+ seq_length = 12 + batch_size = 8 + input_size = 4 + self.hidden_size = 2 + self.set_attrs() + + self.direction_num = 2 if self.is_bidirec else 1 + direction = "bidirectional" if self.is_bidirec else "forward" + + input = np.random.uniform( + low=-0.1, high=0.1, + size=(seq_length, batch_size, input_size)).astype(self.dtype) + + if self.sequence_length is not None: + input[3][1:][:] = 0 + input[4][2:][:] = 0 + input[2][3:][:] = 0 + input[1][4:][:] = 0 + + rnn1 = GRU(input_size, + self.hidden_size, + num_layers=self.num_layers, + time_major=True, + direction=direction, + dropout=self.dropout, + dtype=self.dtype) + + flat_w = get_params_for_net(rnn1) + + output, last_hidden = rnn1(input, sequence_length=self.sequence_length) + + init_h = np.zeros((self.num_layers * self.direction_num, batch_size, + self.hidden_size)).astype(self.dtype) + + state_out = np.ndarray((300)).astype("uint8") + + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h)], + 'SequenceLength': self.sequence_length + } + if self.sequence_length is None: + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h)], + } + self.attrs = { + 'dropout_prob': self.dropout, + 'is_bidirec': self.is_bidirec, + 'input_size': input_size, + 'hidden_size': self.hidden_size, + 'num_layers': self.num_layers, + 'is_test': self.is_test, + 'mode': self.mode + } + self.outputs = { + 'Out': output, + 'State': [('last_hidden', last_hidden)], + 'Reserve': np.ndarray((400)).astype("uint8"), + 'DropoutState': state_out + } + + def set_attrs(self): + pass + + def test_output(self): + self.check_output(no_check_set=['Reserve', 'DropoutState']) + + def test_grad(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['Input', 'init_h'] + grad_check_list.extend(var_name_list) + self.check_grad(set(grad_check_list), ['Out', 'last_hidden']) + + +class TestGRUOp1(TestGRUOp): + def set_attrs(self): + self.sequence_length = None + + +class TestGRUOp2(TestGRUOp): + def set_attrs(self): + self.sequence_length = None + self.is_bidirec = True + + +class TestGRUOp3(TestGRUOp): + def set_attrs(self): + self.sequence_length = None + self.is_test = True + + +class TestGRUOp4(TestGRUOp): + def set_attrs(self): + self.sequence_length = None + self.is_bidirec = True + self.is_test = True + + +class TestGRUOpAvx(TestGRUOp): + def set_attrs(self): + self.dtype = "float32" + self.hidden_size = 8 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rnn_op.py b/python/paddle/fluid/tests/unittests/test_rnn_op.py new file mode 100644 index 0000000000..af3add34d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_rnn_op.py @@ -0,0 +1,159 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +import math +import paddle.fluid.core as core +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import random +import sys + +from op_test import OpTest +sys.path.append("./rnn") +from rnn_numpy import SimpleRNN, LSTM, GRU +from convert import get_params_for_net + +random.seed(2) +np.set_printoptions(threshold=np.inf) +paddle.enable_static() + + +class TestRNNOp(OpTest): + def get_weight_names(self): + weight_names = [] + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.weight_{}".format(i, j)) + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.bias_{}".format(i, j)) + return weight_names + + def setUp(self): + self.op_type = "rnn" + self.dtype = np.float64 + self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32) + self.num_layers = 1 + self.is_bidirec = False + self.mode = "LSTM" + self.is_test = False + self.set_attrs() + + self.direction_num = 2 if self.is_bidirec else 1 + direction = "bidirectional" if self.is_bidirec else "forward" + seq_length = 12 + batch_size = 5 + input_size = 3 + hidden_size = 2 + + input = np.random.uniform( + low=-0.1, high=0.1, + size=(seq_length, batch_size, input_size)).astype(self.dtype) + if self.sequence_length is not None: + input[11][1:][:] = 0 + input[10][2:][:] = 0 + input[9][3:][:] = 0 + input[8][4:][:] = 0 + + rnn1 = LSTM( + input_size, + hidden_size, + num_layers=self.num_layers, + time_major=True, + direction=direction) + + flat_w = get_params_for_net(rnn1) + output, (last_hidden, last_cell) = rnn1( + input, sequence_length=self.sequence_length) + + init_h = np.zeros((self.num_layers * self.direction_num, batch_size, + hidden_size)).astype(self.dtype) + init_c = np.zeros((self.num_layers * self.direction_num, batch_size, + hidden_size)).astype(self.dtype) + state_out = np.ndarray((300)).astype("uint8") + + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h), ('init_c', init_c)], + 'SequenceLength': self.sequence_length + } + if self.sequence_length is None: + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h), ('init_c', init_c)], + } + self.attrs = { + 'dropout_prob': 0.0, + 'is_bidirec': self.is_bidirec, + 'input_size': input_size, + 'hidden_size': hidden_size, + 'num_layers': self.num_layers, + 'mode': self.mode, + 'is_test': self.is_test + } + self.outputs = { + 'Out': output, + "State": [('last_hidden', last_hidden), ('last_cell', last_cell)], + 'Reserve': np.ndarray((400)).astype("uint8"), + 'DropoutState': state_out + } + + def test_output(self): + self.check_output(no_check_set=['Reserve', 'DropoutState']) + + def set_attrs(self): + pass + + def test_grad(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['Input', 'init_h', 'init_c'] + grad_check_list.extend(var_name_list) + self.check_grad( + set(grad_check_list), ['Out', 'last_hidden', 'last_cell']) + + +class TestRNNOp1(TestRNNOp): + def set_attrs(self): + self.sequence_length = None + + +class TestRNNOp2(TestRNNOp): + def set_attrs(self): + self.sequence_length = None + self.is_bidirec = True + + +class TestRNNOp3(TestRNNOp): + def set_attrs(self): + self.is_test = True + self.sequence_length = None + + +class TestRNNOp4(TestRNNOp): + def set_attrs(self): + self.is_test = True + self.sequence_length = None + self.is_bidirec = True + 
+ +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py b/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py new file mode 100644 index 0000000000..63688cbce2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_simple_rnn_op.py @@ -0,0 +1,162 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import math + +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import random +import sys +sys.path.append("./rnn") +from rnn_numpy import SimpleRNN +from convert import get_params_for_net + +random.seed(2) +np.set_printoptions(threshold=np.inf) +paddle.enable_static() + + +class TestSimpleRNNOp(OpTest): + def get_weight_names(self): + weight_names = [] + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.weight_{}".format(i, j)) + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.bias_{}".format(i, j)) + return weight_names + + def setUp(self): + self.op_type = "rnn" + self.dtype = np.float64 + self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32) + self.num_layers = 1 + self.is_bidirec = False + self.is_test = False + self.mode = "RNN_TANH" + self.dropout = 0. 
+ self.set_attrs() + + self.direction_num = 2 if self.is_bidirec else 1 + direction = "bidirectional" if self.is_bidirec else "forward" + seq_length = 12 + batch_size = 5 + input_size = 3 + hidden_size = 2 + + input = np.random.uniform( + low=-0.1, high=0.1, + size=(seq_length, batch_size, input_size)).astype(self.dtype) + if self.sequence_length is not None: + input[11][1:][:] = 0 + input[10][2:][:] = 0 + input[9][3:][:] = 0 + input[8][4:][:] = 0 + + rnn1 = SimpleRNN( + input_size, + hidden_size, + num_layers=self.num_layers, + time_major=True, + direction=direction, + dropout=self.dropout, + nonlinearity=self.mode) + + flat_w = get_params_for_net(rnn1) + + output, last_hidden = rnn1(input, sequence_length=self.sequence_length) + + init_h = np.zeros((self.num_layers * self.direction_num, batch_size, + hidden_size)).astype(self.dtype) + + state_out = np.ndarray((300)).astype("uint8") + + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h)], + 'SequenceLength': self.sequence_length + } + if self.sequence_length is None: + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h)] + } + self.attrs = { + 'dropout_prob': self.dropout, + 'is_bidirec': self.is_bidirec, + 'input_size': input_size, + 'hidden_size': hidden_size, + 'num_layers': self.num_layers, + 'is_test': self.is_test, + 'mode': self.mode + } + self.outputs = { + 'Out': output, + 'State': [('last_hidden', last_hidden)], + 'Reserve': np.ndarray((400)).astype("uint8"), + 'DropoutState': state_out + } + + def set_attrs(self): + pass + + def test_output(self): + self.check_output(no_check_set=['Reserve', 'DropoutState']) + + def test_grad(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['Input', 'init_h'] + grad_check_list.extend(var_name_list) + self.check_grad(set(grad_check_list), ['Out', 'last_hidden']) + + +class TestSimpleRNNOp1(TestSimpleRNNOp): + def set_attrs(self): + self.sequence_length = None + + +class TestSimpleRNNOp2(TestSimpleRNNOp): + def set_attrs(self): + self.sequence_length = None + self.is_bidirec = True + + +class TestSimpleRNNOp3(TestSimpleRNNOp): + def set_attrs(self): + self.sequence_length = None + self.is_test = True + + +class TestSimpleRNNOp4(TestSimpleRNNOp): + def set_attrs(self): + self.sequence_length = None + self.is_bidirec = True + self.is_test = True + + +class TestSimpleRNNOp5(TestSimpleRNNOp): + def set_attrs(self): + self.mode = "RNN_RELU" + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py index e19641e710..15f28d94c7 100644 --- a/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/check_shape_white_list.py @@ -27,4 +27,5 @@ NEED_TO_FIX_OP_LIST = [ 'tree_conv', 'cvm', 'cudnn_lstm', + 'rnn', ] diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index afd3414943..24c89408b5 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -28,4 +28,5 @@ no_check_set_white_list = [ 'check_finite_and_unscale', 'update_loss_scaling', 'cudnn_lstm', + 'rnn', ] diff --git 
a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 47d62999c9..6076e9dc9f 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -43,7 +43,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'yolov3_loss', \ 'inverse', \ 'bilateral_slice',\ - 'cudnn_lstm' + 'cudnn_lstm', \ + 'rnn', \ ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\ diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index 75817aa2dc..388dddf262 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -985,8 +985,7 @@ class RNNBase(LayerList): "direction should be forward, backward or bidirectional, " "received direction = {}".format(direction)) - self.could_use_cudnn = get_device().startswith( - "gpu:") and get_cudnn_version() + self.could_use_cudnn = True self.could_use_cudnn &= direction != "backward" self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * ( 2 if direction == "bidirectional" else 1) -- GitLab
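
For reference, a minimal usage sketch of the new numpy test helpers (an editor's illustration, not part of the patch; it assumes the rnn/ helper directory is on sys.path, exactly as the new unit tests arrange it):

    import sys
    sys.path.append("./rnn")

    from rnn_numpy import GRU
    from convert import get_params_for_net

    # build the single-layer, unidirectional, time-major numpy GRU reference
    rnn = GRU(input_size=4, hidden_size=2, num_layers=1,
              time_major=True, direction="forward", dtype="float64")

    # flatten its parameters into the WeightList naming scheme used by the rnn op tests:
    # weights first, then biases, each named "<layer>.weight_<k>" / "<layer>.bias_<k>"
    flat_w = get_params_for_net(rnn)
    print([name for name, _ in flat_w])
    # expected: ['0.weight_0', '0.weight_1', '0.bias_0', '0.bias_1']

These names line up with the list produced by get_weight_names() in the new tests, which is what check_grad differentiates against.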