From fb9a6a2bc6cbc88893544198ca1d9242523e3a06 Mon Sep 17 00:00:00 2001 From: xuezhong Date: Mon, 11 Feb 2019 10:17:02 +0000 Subject: [PATCH] pass test for lstm op test=develop --- paddle/fluid/operators/math/detail/lstm_kernel.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/fluid/operators/math/detail/lstm_kernel.h b/paddle/fluid/operators/math/detail/lstm_kernel.h index e1be0071f29..8149686c97a 100644 --- a/paddle/fluid/operators/math/detail/lstm_kernel.h +++ b/paddle/fluid/operators/math/detail/lstm_kernel.h @@ -37,6 +37,7 @@ class lstm { *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate); *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate); *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg); + if (*cell_clip > 0.0) { if (*state < -1.0 * (*cell_clip)) { *state = -1.0 * (*cell_clip); @@ -73,6 +74,7 @@ class lstm { active_gate); *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig), _mm256_mul_ps(*prev_state, *value_fg)); + if (*cell_clip > 0.0f) { __m256 min = _mm256_set1_ps(0.0f - *cell_clip); __m256 max = _mm256_set1_ps(*cell_clip); @@ -114,7 +116,12 @@ class lstm { activation((*output_grad) * (*value_og), *state_atv, active_state) + (*grad_og) * (*checkO); } + } else { + *state_grad += + activation((*output_grad) * (*value_og), *state_atv, active_state) + + (*grad_og) * (*checkO); } + *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node); *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate); *grad_fg = -- GitLab