From 522b3e411f33400ae2735e81c4bc65ca26438445 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Wed, 22 Aug 2018 19:40:59 +0800
Subject: [PATCH] complete attention lstm op test

---
 .../tests/unittests/test_attention_lstm_op.py | 55 ++++++++++++++++++-
 1 file changed, 52 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
index cd555a022bc..dea6ec76689 100644
--- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 from op_test import OpTest
 from test_fusion_lstm_op import fc, ACTIVATION
+from test_softmax_op import stable_softmax
 
 
 def attention_lstm(
@@ -32,8 +33,56 @@ def attention_lstm(
         act_gate,
         act_cell,
         act_cand):
-    hidden
-    cell
+
+    T = sum(lod[0])
+    N = len(lod[0])
+    M = x.shape[1]
+    D = b.shape[1] // 4
+    assert T == x.shape[0]
+    assert len(fcws) == len(fcbs)
+
+    hidden = []
+    cell = []
+
+    start_offset = 0
+    for bid in range(N):
+        seq_len = lod[0][bid]
+        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len, M)
+        prev_cell = np.copy(c0[bid]).reshape([1, D])
+        prev_hidden = np.copy(h0[bid]).reshape([1, D])
+        for step in range(seq_len):
+            expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
+            tmp = np.concatenate((xi, expanded_cell), axis=1)
+            assert tmp.shape[1] == M + D
+            for fcid in range(len(fcbs)):
+                tmp = fc(tmp, fcws[fcid], fcbs[fcid])
+                tmp = ACTIVATION['relu'](tmp)
+            tmp = np.reshape(tmp, (1, seq_len))
+            tmp = stable_softmax(tmp).reshape(seq_len, 1)  # attention weights
+            lstmx = xi * tmp  # seq_len x M
+            lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
+            lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
+            lstmout = np.dot(lstmin, w).reshape([1, 4 * D])
+
+            g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
+            g_f = act_gate(g_f).reshape([1, D])
+            g_i = act_gate(g_i).reshape([1, D])
+            g_o = act_gate(g_o).reshape([1, D])
+            cand = act_cand(cand).reshape([1, D])
+
+            cell_t = (prev_cell * g_f) + (g_i * cand)
+            hidden_t = g_o * act_cell(cell_t)
+
+            hidden.append(hidden_t.flatten())
+            cell.append(cell_t.flatten())
+
+            prev_cell = cell_t.reshape([1, D])
+            prev_hidden = hidden_t.reshape([1, D])
+
+        start_offset += seq_len
+
+    hidden = np.array(hidden).astype('float32').reshape([T, D])
+    cell = np.array(cell).astype('float32').reshape([T, D])
     return hidden, cell
 
 
@@ -73,7 +122,7 @@ class TestAttentionLSTMOp(OpTest):
         b = np.random.normal(size=(1, self.D * 4)).astype('float32')
 
         h, c = attention_lstm(x, self.lod, h0, c0, [fcw1, fcw2], [fcb1, fcb2],
-                              ACTIVATION[self.act_gate],
+                              w, b, ACTIVATION[self.act_gate],
                               ACTIVATION[self.act_cell],
                               ACTIVATION[self.act_cand])
 
-- 
GitLab
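
Note on the reference implementation added by this patch: per time step, attention_lstm scores every row of the sequence through the small fc layers (relu-activated), normalizes the scores with stable_softmax into attention weights, collapses the sequence into one attended input lstmx, and then runs a standard LSTM cell update on it. The sketch below is a minimal, illustrative way to exercise the function with toy shapes; it assumes attention_lstm from the patch is pasted into the same file, and the fc / ACTIVATION / stable_softmax stand-ins below are hypothetical replacements for the test-suite imports (written from their well-known definitions, not copied from test_fusion_lstm_op.py or test_softmax_op.py).

    import numpy as np

    def fc(x, w, b):
        # stand-in for test_fusion_lstm_op.fc: plain affine transform
        return np.dot(x, w) + b

    def stable_softmax(x):
        # numerically stable softmax: shift by the max before exponentiating
        exps = np.exp(x - np.max(x))
        return exps / np.sum(exps)

    ACTIVATION = {
        'sigmoid': lambda x: 1.0 / (1.0 + np.exp(-x)),
        'tanh': np.tanh,
        'relu': lambda x: np.maximum(x, 0.0),
    }

    # toy shapes: two sequences of lengths 2 and 3; M input width, D cell width
    lod = [[2, 3]]
    T, N, M, D = sum(lod[0]), len(lod[0]), 4, 5
    x = np.random.normal(size=(T, M)).astype('float32')
    h0 = np.zeros((N, D), dtype='float32')   # initial hidden state per sequence
    c0 = np.zeros((N, D), dtype='float32')   # initial cell state per sequence
    fcw = np.random.normal(size=(M + D, 1)).astype('float32')    # one scoring fc layer
    fcb = np.random.normal(size=(1, 1)).astype('float32')
    w = np.random.normal(size=(M + D, 4 * D)).astype('float32')  # LSTM weights
    b = np.random.normal(size=(1, 4 * D)).astype('float32')      # bias for the 4 gates

    h, c = attention_lstm(x, lod, h0, c0, [fcw], [fcb], w, b,
                          ACTIVATION['sigmoid'], ACTIVATION['tanh'],
                          ACTIVATION['tanh'])
    assert h.shape == (T, D) and c.shape == (T, D)  # one output row per time step

The unit test itself follows the same pattern, except that it stacks two scoring fc layers (fcw1/fcb1 and fcw2/fcb2) and compares this NumPy reference against the fused op's output.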