# test_lstm_op.py: NumPy reference LSTM and unit tests for the "lstm" operator.
import unittest
import numpy as np
from op_test import OpTest

# Clamp bounds applied to the input of the reference `sigmoid` before exp().
# NOTE(review): presumably chosen to mirror the native kernel's clamping -- confirm.
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
# Upper cap on the exponent argument (-2x) inside the reference `tanh`.
EXP_MAX_INPUT = 40.0


def identity(x):
    """Identity activation: hand the argument back unchanged (same object)."""
    result = x
    return result


def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` computed on a clamped input.

    ``x`` is first clamped to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX],
    so the argument passed to ``exp`` stays within
    [-SIGMOID_THRESHOLD_MAX, -SIGMOID_THRESHOLD_MIN]. ``x`` is not modified.
    """
    y = np.clip(x, SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX)
    return 1. / (1. + np.exp(-y))


def tanh(x):
    """Hyperbolic tangent evaluated as ``2 / (1 + exp(-2x)) - 1``.

    The exponent ``-2x`` is capped at EXP_MAX_INPUT so that very negative
    inputs do not overflow ``exp``. ``x`` is not modified.
    """
    y = np.minimum(-2. * x, EXP_MAX_INPUT)
    return (2. / (1. + np.exp(y))) - 1.


def relu(x):
    """Rectified linear unit, applied element-wise: ``max(x, 0)``."""
    floor = 0
    return np.maximum(x, floor)


# Activation name (as used in the test configuration / op attributes) ->
# reference NumPy implementation defined in this module.
ACTIVATION = {
    'identity': identity,
    'sigmoid': sigmoid,
    'tanh': tanh,
    'relu': relu
}
# Backward-compatible alias for the original, misspelled name.
ACTVATION = ACTIVATION


def lstm(
        input,  # T x 4D
        lod,  # 1 x N
        h0=None,  # N x D
        c0=None,  # N x D
        w_h=None,  # D x 4D
        w_b=None,  # 1 x 4D
        w_c=None,  # 1 x 3D
        is_reverse=False,
        act_gate=None,
        act_cell=None,
        act_cand=None):
    """Reference NumPy LSTM forward pass over variable-length sequences.

    Args:
        input: (T, 4D) pre-projected inputs for every row of every sequence.
            Columns are split in the order
            [candidate | input gate | forget gate | output gate].
        lod: ``lod[0]`` holds N + 1 row offsets delimiting the N sequences
            inside ``input``.
        h0, c0: (N, D) initial hidden / cell state per sequence; ``None``
            means all zeros.
        w_h: (D, 4D) recurrent weight applied to the previous hidden state.
        w_b: optional (1, 4D) bias added to every row of ``input``.
        w_c: optional (1, 3D) peephole weights [input | forget | output];
            ``None`` disables peephole connections.
        is_reverse: process each sequence from its last row to its first.
        act_gate, act_cell, act_cand: activation callables for the gates,
            for the cell before the output gate, and for the candidate.

    Returns:
        ``(hidden, cell, gate)``. ``hidden`` and ``cell`` are (T, D) float64
        arrays in the original row order. ``gate`` is a (T, 4D) float64 array
        holding ``[act_cand(candidate), i, f, o]`` per step. It is left in
        processing order, i.e. still reversed within each sequence when
        ``is_reverse`` is set.
    """
    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
        # One LSTM time step: returns the new hidden state, the new cell
        # state and the concatenated activated gates.
        g = np.dot(h_pre, w_h) + x  # 1 x 4D
        g = np.reshape(g, (1, g.size))
        c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
        cand = act_cand(c_tmp)  # 1 x D
        if w_c is None:
            g_i = act_gate(g_i)  # 1 x D
            g_f = act_gate(g_f)  # 1 x D
        else:
            w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
            # Peepholes: input/forget gates look at the previous cell state.
            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
        c = g_f * c_pre + g_i * cand  # 1 x D

        if w_c is None:
            g_o = act_gate(g_o)  # 1 x D
        else:
            # ...whereas the output gate looks at the freshly updated cell.
            g_o = act_gate(g_o + w_oc * c)  # 1 x D
        h = g_o * act_cell(c)
        bg = np.concatenate((cand, g_i, g_f, g_o), axis=1)
        return h, c, bg

    def _reverse(x, lod):
        # Flip the rows of every sequence delimited by `lod`, in place order.
        y = np.zeros_like(x)
        for i in range(len(lod) - 1):
            b, e = lod[i], lod[i + 1]
            y[b:e, :] = np.flip(x[b:e, :], 0)
        return y

    offset = lod[0]
    batch_size = len(offset) - 1
    # Integer hidden size; `//` keeps it an int on Python 3 as well.
    frame_size = input.shape[1] // 4
    if h0 is None:
        h0 = np.zeros((batch_size, frame_size))
    if c0 is None:
        c0 = np.zeros((batch_size, frame_size))

    hidden = []
    cell = []
    gate = []
    input = _reverse(input, offset) if is_reverse else input
    if w_b is not None:
        input = input + np.tile(w_b, (offset[-1], 1))
    for i in range(batch_size):
        # compute one sequence
        seq_len = offset[i + 1] - offset[i]
        x = input[offset[i]:offset[i + 1], :]
        h_pre = h0[i]  # 1 x D
        c_pre = c0[i]  # 1 x D
        for j in range(seq_len):
            # compute one step
            h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
                                        act_cell, act_cand)
            hidden.append(h_pre.flatten())
            cell.append(c_pre.flatten())
            gate.append(g_pre.flatten())

    hidden = np.array(hidden).astype("float64")
    cell = np.array(cell).astype("float64")
    gate = np.array(gate).astype("float64")

    # Restore original row order for the states; `gate` is intentionally
    # kept in processing order (see Returns).
    hidden = _reverse(hidden, offset) if is_reverse else hidden
    cell = _reverse(cell, offset) if is_reverse else cell

    assert gate.shape == input.shape
    assert hidden.shape == (input.shape[0], frame_size)
    assert cell.shape == (input.shape[0], frame_size)
    return hidden, cell, gate


class TestLstmOp(OpTest):
    """Compare the "lstm" operator with the NumPy reference ``lstm``."""

    def set_data(self):
        """Configure the test case; subclasses override this to vary it."""
        # Row offsets 0, 2, 6, 9 -> three sequences of lengths 2, 4 and 3.
        self.lod = [[0, 2, 6, 9]]
        self.D = 64
        # BatchGate row i is taken from reference gate row sort_idx[i].
        # NOTE(review): presumably the op's batch layout (sequences ordered by
        # length, interleaved by time step) -- confirm against the kernel.
        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

        self.act_gate = "sigmoid"
        self.act_cell = "tanh"
        self.act_cand = "tanh"

        self.is_reverse = False

    def setUp(self):
        """Build random inputs, run the reference and record expectations."""
        self.set_data()
        self.op_type = "lstm"

        T = self.lod[0][-1]
        N = len(self.lod[0]) - 1

        x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
        h0 = np.zeros((N, self.D)).astype("float64")
        c0 = np.zeros((N, self.D)).astype("float64")
        w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
        # Bias packs the 4D gate bias followed by the 3D peephole weights.
        b = np.random.normal(size=(1, 7 * self.D)).astype("float64")

        w_b = b[:, 0:4 * self.D]
        w_c = b[:, 4 * self.D:]
        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                       ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                       ACTVATION[self.act_cand])

        # Reorder the reference gates into the operator's BatchGate rows.
        g_sort = np.zeros_like(x)
        for i, j in enumerate(self.sort_idx):
            g_sort[i, :] = g[j, :]

        self.inputs = {
            'Input': (x, self.lod),
            'H0': h0,
            'C0': c0,
            'Weight': w,
            'Bias': b
        }
        self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
        self.attrs = {
            # The reference always receives w_c, so peepholes are enabled.
            'usePeepholes': True,
            'isReverse': self.is_reverse,
            # Use the configured activations so the operator and the
            # reference can never silently disagree in a subclass.
            'gateActivation': self.act_gate,
            'cellActivation': self.act_cell,
            'candidateActivation': self.act_cand
        }

    def test_check_output(self):
        """Delegate to OpTest.check_output (assumed to compare against
        self.outputs -- confirm in op_test)."""
        self.check_output()


class TestLstmOpRerverse(TestLstmOp):
    """Same configuration as TestLstmOp, with reverse processing enabled."""
    # NOTE(review): "Rerverse" is a typo; kept so existing test IDs still match.

    def set_data(self):
        """Reuse the base configuration and switch ``is_reverse`` on."""
        super(TestLstmOpRerverse, self).set_data()
        self.is_reverse = True


# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()