提交 fb9a6a2b 编写于 作者: X xuezhong

Pass the tests for the LSTM op

test=develop
上级 1abb0d83
...@@ -37,6 +37,7 @@ class lstm { ...@@ -37,6 +37,7 @@ class lstm {
*value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate); *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
*value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate); *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
*state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg); *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
if (*cell_clip > 0.0) { if (*cell_clip > 0.0) {
if (*state < -1.0 * (*cell_clip)) { if (*state < -1.0 * (*cell_clip)) {
*state = -1.0 * (*cell_clip); *state = -1.0 * (*cell_clip);
...@@ -73,6 +74,7 @@ class lstm { ...@@ -73,6 +74,7 @@ class lstm {
active_gate); active_gate);
*state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig), *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
_mm256_mul_ps(*prev_state, *value_fg)); _mm256_mul_ps(*prev_state, *value_fg));
if (*cell_clip > 0.0f) { if (*cell_clip > 0.0f) {
__m256 min = _mm256_set1_ps(0.0f - *cell_clip); __m256 min = _mm256_set1_ps(0.0f - *cell_clip);
__m256 max = _mm256_set1_ps(*cell_clip); __m256 max = _mm256_set1_ps(*cell_clip);
...@@ -114,7 +116,12 @@ class lstm { ...@@ -114,7 +116,12 @@ class lstm {
activation((*output_grad) * (*value_og), *state_atv, active_state) + activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO); (*grad_og) * (*checkO);
} }
} else {
*state_grad +=
activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO);
} }
*grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node); *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
*grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate); *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
*grad_fg = *grad_fg =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册